Skip to content

Commit

Permalink
Add API to get managed identity source
Browse files Browse the repository at this point in the history
  • Loading branch information
neha-bhargava committed May 29, 2024
1 parent fb02867 commit 9a07472
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ protected String getMessageFromErrorResponse(IHttpResponse response) {
managedIdentityErrorResponse.getError(), managedIdentityErrorResponse.getErrorDescription());
}

protected static IEnvironmentVariables getEnvironmentVariables(ManagedIdentityParameters parameters) {
return parameters.environmentVariables == null ? new EnvironmentVariables() : parameters.environmentVariables;
protected static IEnvironmentVariables getEnvironmentVariables() {
return ManagedIdentityApplication.environmentVariables == null ?
new EnvironmentVariables() : ManagedIdentityApplication.environmentVariables;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String msiSecret = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
{
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
super(msalRequest, serviceBundle, ManagedIdentitySourceType.IMDS);
ManagedIdentityParameters parameters = (ManagedIdentityParameters) msalRequest.requestContext().apiParameters();
IEnvironmentVariables environmentVariables = ((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()).environmentVariables == null ?
new EnvironmentVariables() :
parameters.environmentVariables;
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))){
LOG.info(String.format("[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: %s", environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST)));
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

package com.microsoft.aad.msal4j;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;
Expand All @@ -22,6 +24,16 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen
@Getter
static TokenCache sharedTokenCache = new TokenCache();

@Getter(value = AccessLevel.PUBLIC)
static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();

@Getter(value = AccessLevel.PACKAGE)
static IEnvironmentVariables environmentVariables;

static void setEnvironmentVariables(IEnvironmentVariables environmentVariables) {
ManagedIdentityApplication.environmentVariables = environmentVariables;
}

private ManagedIdentityApplication(Builder builder) {
super(builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

package com.microsoft.aad.msal4j;

import lombok.AccessLevel;
import lombok.Getter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -12,6 +14,37 @@
class ManagedIdentityClient {
private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class);

private static ManagedIdentitySourceType managedIdentitySourceType;

protected static void resetManagedIdentitySourceType() {
managedIdentitySourceType = ManagedIdentitySourceType.NONE;
}

static ManagedIdentitySourceType getManagedIdentitySource() {
if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) {
return managedIdentitySourceType;
}

IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables();

if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) {
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC;
} else
managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE;
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL;
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC;
} else {
managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS;
}

return managedIdentitySourceType;
}

AbstractManagedIdentitySource managedIdentitySource;

ManagedIdentityClient(MsalRequest msalRequest, ServiceBundle serviceBundle) {
Expand All @@ -38,16 +71,22 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
AbstractManagedIdentitySource managedIdentitySource;
if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);

if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) {
managedIdentitySourceType = getManagedIdentitySource();
}

switch (managedIdentitySourceType) {
case SERVICE_FABRIC:
return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle);
case APP_SERVICE:
return AppServiceManagedIdentitySource.create(msalRequest, serviceBundle);
case CLOUD_SHELL:
return CloudShellManagedIdentitySource.create(msalRequest, serviceBundle);
case AZURE_ARC:
return AzureArcManagedIdentitySource.create(msalRequest, serviceBundle);
default:
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters {

boolean forceRefresh;

IEnvironmentVariables environmentVariables;

@Override
public Set<String> scopes() {
return null;
Expand All @@ -54,10 +52,6 @@ public Map<String, String> extraQueryParameters() {
return null;
}

void setEnvironmentVariablesConfig(IEnvironmentVariables environmentVariables) {
this.environmentVariables = environmentVariables;
}

private static ManagedIdentityParametersBuilder builder() {
return new ManagedIdentityParametersBuilder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
enum ManagedIdentitySourceType {
// Default.
NONE,
// The source to acquire token for managed identity is IMDS.
// The source used to acquire token for managed identity is IMDS.
IMDS,
// The source to acquire token for managed identity is App Service.
// The source used to acquire token for managed identity is App Service.
APP_SERVICE,
// The source to acquire token for managed identity is Azure Arc.
// The source used to acquire token for managed identity is Azure Arc.
AZURE_ARC,
// The source to acquire token for managed identity is Cloud Shell.
// The source used to acquire token for managed identity is Cloud Shell.
CLOUD_SHELL,
// The source to acquire token for managed identity is Service Fabric.
SERVICE_FABRIC
// The source used to acquire token for managed identity is Service Fabric.
SERVICE_FABRIC,
// The source to acquire token for managed identity is defaulted to IMDS when no environment variables are set.
DEFAULT_TO_IMDS
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,11 @@ public ManagedIdentityResponse getManagedIdentityResponse(

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);


if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
{
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,14 @@ public static Stream<Arguments> createDataError() {
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint));
}

public static Stream<Arguments> createDataGetSource() {
return Stream.of(
Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, ManagedIdentitySourceType.AZURE_ARC),
Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, ManagedIdentitySourceType.APP_SERVICE),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, ManagedIdentitySourceType.CLOUD_SHELL),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS),
Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC));
}
}
Loading

0 comments on commit 9a07472

Please sign in to comment.