Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class AcquireTokenByManagedIdentitySupplier extends AuthenticationResultSupplier

private static final Logger LOG = LoggerFactory.getLogger(AcquireTokenByManagedIdentitySupplier.class);

private static final int TWO_HOURS = 2*3600;

private ManagedIdentityParameters managedIdentityParameters;

AcquireTokenByManagedIdentitySupplier(ManagedIdentityApplication managedIdentityApplication, MsalRequest msalRequest) {
Expand Down Expand Up @@ -93,15 +95,27 @@ private AuthenticationResult fetchNewAccessTokenAndSaveToCache(TokenRequestExecu
}

private AuthenticationResult createFromManagedIdentityResponse(ManagedIdentityResponse managedIdentityResponse) {
long expiresOn = Long.valueOf(managedIdentityResponse.expiresOn);
long refreshOn = expiresOn > 2 * 3600 ? (expiresOn / 2) : 0L;
long expiresOn = Long.parseLong(managedIdentityResponse.expiresOn);
long refreshOn = calculateRefreshOn(expiresOn);
AuthenticationResultMetadata metadata = AuthenticationResultMetadata.builder()
.refreshOn(refreshOn)
.build();

return AuthenticationResult.builder()
.accessToken(managedIdentityResponse.getAccessToken())
.scopes(managedIdentityParameters.resource())
.expiresOn(expiresOn)
.extExpiresOn(0)
.refreshOn(refreshOn)
.metadata(metadata)
.build();
}

private long calculateRefreshOn(long expiresOn){
long timestampSeconds = System.currentTimeMillis() / 1000;
long expiresIn = expiresOn - timestampSeconds;

//The refreshOn value should be half the value of the token lifetime, if the lifetime is greater than two hours
return expiresIn > TWO_HOURS ? (expiresIn / 2) + timestampSeconds : 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen
static TokenCache sharedTokenCache = new TokenCache();

@Getter(value = AccessLevel.PUBLIC)
static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();
ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();

@Getter(value = AccessLevel.PACKAGE)
static IEnvironmentVariables environmentVariables;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,24 @@
class ManagedIdentityClient {
private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class);

private static ManagedIdentitySourceType managedIdentitySourceType;

protected static void resetManagedIdentitySourceType() {
managedIdentitySourceType = ManagedIdentitySourceType.NONE;
}

static ManagedIdentitySourceType getManagedIdentitySource() {
if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) {
return managedIdentitySourceType;
}

IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables();

if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) {
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC;
return ManagedIdentitySourceType.SERVICE_FABRIC;
} else {
managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE;
return ManagedIdentitySourceType.APP_SERVICE;
}
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL;
return ManagedIdentitySourceType.CLOUD_SHELL;
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) {
managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC;
return ManagedIdentitySourceType.AZURE_ARC;
} else {
managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS;
return ManagedIdentitySourceType.DEFAULT_TO_IMDS;
}

return managedIdentitySourceType;
}

AbstractManagedIdentitySource managedIdentitySource;
Expand All @@ -64,11 +52,7 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {

if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) {
managedIdentitySourceType = getManagedIdentitySource();
}

switch (managedIdentitySourceType) {
switch (getManagedIdentitySource()) {
case SERVICE_FABRIC:
return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle);
case APP_SERVICE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ManagedIdentityTests {
private static ManagedIdentityApplication miApp;

private String getSuccessfulResponse(String resource) {
long expiresOn = Instant.now().plus(1, ChronoUnit.HOURS).getEpochSecond();
long expiresOn = (System.currentTimeMillis() / 1000) + (24 * 3600);//A long-lived, 24 hour token
return "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" +
"\"Bearer\",\"client_id\":\"client_id\"}";
}
Expand Down Expand Up @@ -155,18 +155,22 @@ private HttpResponse expectedResponse(int statusCode, String response) {
void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();

ManagedIdentitySourceType managedIdentitySourceType = ManagedIdentityClient.getManagedIdentitySource();
assertEquals(expectedSource, managedIdentitySourceType);
miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.build();

ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource();
ManagedIdentitySourceType miAppSourceType = miApp.managedIdentitySource;
assertEquals(expectedSource, miClientSourceType);
assertEquals(expectedSource, miAppSourceType);
}

@ParameterizedTest
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData")
void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand Down Expand Up @@ -201,7 +205,6 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource
void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand All @@ -222,12 +225,38 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy
verify(httpClientMock, times(1)).send(any());
}

@Test
void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception {
//All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set,
// so any of the MI options should let us verify that it's being set correctly
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.APP_SERVICE, appServiceEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));

miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

AuthenticationResult result = (AuthenticationResult) miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.build()).get();

long timestampSeconds = (System.currentTimeMillis() / 1000);

assertNotNull(result.accessToken());
assertEquals((result.expiresOn() - timestampSeconds)/2, result.refreshOn() - timestampSeconds);

verify(httpClientMock, times(1)).send(any());
}

@ParameterizedTest
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported")
void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

miApp = ManagedIdentityApplication
Expand Down Expand Up @@ -264,7 +293,6 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT

IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand Down Expand Up @@ -298,7 +326,6 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT
void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

if (environmentVariables.getEnvironmentVariable("SourceType").equals(ManagedIdentitySourceType.CLOUD_SHELL.toString())) {
Expand Down Expand Up @@ -337,7 +364,6 @@ void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String en
void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

miApp = ManagedIdentityApplication
Expand Down Expand Up @@ -388,7 +414,6 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint
void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ""));
Expand Down Expand Up @@ -423,7 +448,6 @@ void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, S
void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ""));
Expand Down Expand Up @@ -458,7 +482,6 @@ void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source
void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network."));
Expand Down Expand Up @@ -492,7 +515,6 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
Expand Down Expand Up @@ -531,7 +553,6 @@ void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
Expand Down Expand Up @@ -572,7 +593,6 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
Expand Down Expand Up @@ -611,7 +631,6 @@ void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
void azureArcManagedIdentityAuthheaderValidationTest() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
ManagedIdentityClient.resetManagedIdentitySourceType();
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

//Both a missing file and an invalid path structure should throw an exception
Expand Down