From 9a07472cedbde7a1013a6fc81da91f2a2c6d67a1 Mon Sep 17 00:00:00 2001 From: Neha Bhargava <61847233+neha-bhargava@users.noreply.github.com> Date: Wed, 29 May 2024 14:22:53 -0700 Subject: [PATCH] Add API to get managed identity source --- .../msal4j/AbstractManagedIdentitySource.java | 5 +- .../AppServiceManagedIdentitySource.java | 2 +- .../msal4j/AzureArcManagedIdentitySource.java | 2 +- .../CloudShellManagedIdentitySource.java | 2 +- .../aad/msal4j/IMDSManagedIdentitySource.java | 4 +- .../msal4j/ManagedIdentityApplication.java | 12 ++++ .../aad/msal4j/ManagedIdentityClient.java | 59 +++++++++++++++---- .../aad/msal4j/ManagedIdentityParameters.java | 6 -- .../aad/msal4j/ManagedIdentitySourceType.java | 14 +++-- .../ServiceFabricManagedIdentitySource.java | 3 +- .../ManagedIdentityTestDataProvider.java | 10 ++++ .../aad/msal4j/ManagedIdentityTests.java | 54 +++++++++++------ 12 files changed, 124 insertions(+), 49 deletions(-) diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java index a6160cf2..260936b5 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java @@ -133,7 +133,8 @@ protected String getMessageFromErrorResponse(IHttpResponse response) { managedIdentityErrorResponse.getError(), managedIdentityErrorResponse.getErrorDescription()); } - protected static IEnvironmentVariables getEnvironmentVariables(ManagedIdentityParameters parameters) { - return parameters.environmentVariables == null ? new EnvironmentVariables() : parameters.environmentVariables; + protected static IEnvironmentVariables getEnvironmentVariables() { + return ManagedIdentityApplication.environmentVariables == null ? + new EnvironmentVariables() : ManagedIdentityApplication.environmentVariables; } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AppServiceManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AppServiceManagedIdentitySource.java index 843b9b61..2b26b109 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AppServiceManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AppServiceManagedIdentitySource.java @@ -56,7 +56,7 @@ private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) { - IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()); + IEnvironmentVariables environmentVariables = getEnvironmentVariables(); String msiSecret = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER); String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java index ebedb690..f3696b7d 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java @@ -27,7 +27,7 @@ class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{ static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) { - IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()); + IEnvironmentVariables environmentVariables = getEnvironmentVariables(); String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT); String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CloudShellManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CloudShellManagedIdentitySource.java index 6e337d6c..b0563676 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CloudShellManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CloudShellManagedIdentitySource.java @@ -46,7 +46,7 @@ private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) { - IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()); + IEnvironmentVariables environmentVariables = getEnvironmentVariables(); String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java index ce7dfd51..0dc2d2b1 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java @@ -35,9 +35,7 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle) { super(msalRequest, serviceBundle, ManagedIdentitySourceType.IMDS); ManagedIdentityParameters parameters = (ManagedIdentityParameters) msalRequest.requestContext().apiParameters(); - IEnvironmentVariables environmentVariables = ((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()).environmentVariables == null ? - new EnvironmentVariables() : - parameters.environmentVariables; + IEnvironmentVariables environmentVariables = getEnvironmentVariables(); if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))){ LOG.info(String.format("[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: %s", environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))); try { diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java index fcb4a58e..a260018a 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java @@ -3,7 +3,9 @@ package com.microsoft.aad.msal4j; +import lombok.AccessLevel; import lombok.Getter; +import lombok.Setter; import org.slf4j.LoggerFactory; import java.util.concurrent.CompletableFuture; @@ -22,6 +24,16 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen @Getter static TokenCache sharedTokenCache = new TokenCache(); + @Getter(value = AccessLevel.PUBLIC) + static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource(); + + @Getter(value = AccessLevel.PACKAGE) + static IEnvironmentVariables environmentVariables; + + static void setEnvironmentVariables(IEnvironmentVariables environmentVariables) { + ManagedIdentityApplication.environmentVariables = environmentVariables; + } + private ManagedIdentityApplication(Builder builder) { super(builder); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java index 7f6b2be1..334abec8 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java @@ -3,6 +3,8 @@ package com.microsoft.aad.msal4j; +import lombok.AccessLevel; +import lombok.Getter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -12,6 +14,37 @@ class ManagedIdentityClient { private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class); + private static ManagedIdentitySourceType managedIdentitySourceType; + + protected static void resetManagedIdentitySourceType() { + managedIdentitySourceType = ManagedIdentitySourceType.NONE; + } + + static ManagedIdentitySourceType getManagedIdentitySource() { + if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) { + return managedIdentitySourceType; + } + + IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables(); + + if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) && + !StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) { + if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) { + managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC; + } else + managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE; + } else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) { + managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL; + } else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) && + !StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) { + managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC; + } else { + managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS; + } + + return managedIdentitySourceType; + } + AbstractManagedIdentitySource managedIdentitySource; ManagedIdentityClient(MsalRequest msalRequest, ServiceBundle serviceBundle) { @@ -38,16 +71,22 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle) { AbstractManagedIdentitySource managedIdentitySource; - if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) { - return managedIdentitySource; - } else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) { - return managedIdentitySource; - } else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) { - return managedIdentitySource; - } else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) { - return managedIdentitySource; - } else { - return new IMDSManagedIdentitySource(msalRequest, serviceBundle); + + if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) { + managedIdentitySourceType = getManagedIdentitySource(); + } + + switch (managedIdentitySourceType) { + case SERVICE_FABRIC: + return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle); + case APP_SERVICE: + return AppServiceManagedIdentitySource.create(msalRequest, serviceBundle); + case CLOUD_SHELL: + return CloudShellManagedIdentitySource.create(msalRequest, serviceBundle); + case AZURE_ARC: + return AzureArcManagedIdentitySource.create(msalRequest, serviceBundle); + default: + return new IMDSManagedIdentitySource(msalRequest, serviceBundle); } } } \ No newline at end of file diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java index 7f38c61c..f994e495 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java @@ -27,8 +27,6 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters { boolean forceRefresh; - IEnvironmentVariables environmentVariables; - @Override public Set scopes() { return null; @@ -54,10 +52,6 @@ public Map extraQueryParameters() { return null; } - void setEnvironmentVariablesConfig(IEnvironmentVariables environmentVariables) { - this.environmentVariables = environmentVariables; - } - private static ManagedIdentityParametersBuilder builder() { return new ManagedIdentityParametersBuilder(); } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentitySourceType.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentitySourceType.java index 66bddf6a..14a3b6e0 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentitySourceType.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentitySourceType.java @@ -6,14 +6,16 @@ enum ManagedIdentitySourceType { // Default. NONE, - // The source to acquire token for managed identity is IMDS. + // The source used to acquire token for managed identity is IMDS. IMDS, - // The source to acquire token for managed identity is App Service. + // The source used to acquire token for managed identity is App Service. APP_SERVICE, - // The source to acquire token for managed identity is Azure Arc. + // The source used to acquire token for managed identity is Azure Arc. AZURE_ARC, - // The source to acquire token for managed identity is Cloud Shell. + // The source used to acquire token for managed identity is Cloud Shell. CLOUD_SHELL, - // The source to acquire token for managed identity is Service Fabric. - SERVICE_FABRIC + // The source used to acquire token for managed identity is Service Fabric. + SERVICE_FABRIC, + // The source to acquire token for managed identity is defaulted to IMDS when no environment variables are set. + DEFAULT_TO_IMDS } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java index 69aace93..804eb486 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java @@ -95,12 +95,11 @@ public ManagedIdentityResponse getManagedIdentityResponse( static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) { - IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()); + IEnvironmentVariables environmentVariables = getEnvironmentVariables(); String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT); String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER); String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT); - if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint)) { LOG.info("[Managed Identity] Service fabric managed identity is unavailable."); diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java index 6ba15d7d..5132bc83 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java @@ -97,4 +97,14 @@ public static Stream createDataError() { Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT), Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint)); } + + public static Stream createDataGetSource() { + return Stream.of( + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, ManagedIdentitySourceType.AZURE_ARC), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, ManagedIdentitySourceType.APP_SERVICE), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, ManagedIdentitySourceType.CLOUD_SHELL), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS), + Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC)); + } } diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java index 47f30e90..d288804d 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java @@ -144,10 +144,23 @@ private HttpResponse expectedResponse(int statusCode, String response) { return httpResponse; } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataGetSource") + void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) { + IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); + + ManagedIdentitySourceType managedIdentitySourceType = ManagedIdentityClient.getManagedIdentitySource(); + assertEquals(expectedSource, managedIdentitySourceType); + } + @ParameterizedTest @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); @@ -162,7 +175,6 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(result.accessToken()); @@ -171,7 +183,6 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(result.accessToken()); @@ -183,6 +194,8 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssigned") void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); @@ -197,7 +210,6 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(result.accessToken()); @@ -208,6 +220,8 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported") void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); miApp = ManagedIdentityApplication @@ -221,7 +235,6 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou try { IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception e) { assertNotNull(e); @@ -244,6 +257,8 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT String anotherResource = "https://graph.microsoft.com"; IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); @@ -259,14 +274,12 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(result.accessToken()); result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(anotherResource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(result.accessToken()); @@ -278,6 +291,8 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); if (environmentVariables.getEnvironmentVariable("SourceType").equals(ManagedIdentitySourceType.CLOUD_SHELL.toString())) { @@ -297,7 +312,6 @@ void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String en try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -316,6 +330,8 @@ void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String en @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); miApp = ManagedIdentityApplication @@ -332,7 +348,6 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -349,7 +364,6 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -367,6 +381,8 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, "")); @@ -382,7 +398,6 @@ void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, S try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -401,6 +416,8 @@ void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, S @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, "")); @@ -416,7 +433,6 @@ void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -435,6 +451,8 @@ void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); @@ -450,7 +468,6 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -468,6 +485,8 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType @Test void azureArcManagedIdentity_MissingAuthHeader() throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); HttpResponse response = new HttpResponse(); @@ -486,7 +505,6 @@ void azureArcManagedIdentity_MissingAuthHeader() throws Exception { try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -506,6 +524,8 @@ void azureArcManagedIdentity_MissingAuthHeader() throws Exception { @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); @@ -525,14 +545,12 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi IAuthenticationResult resultMiApp1 = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(resultMiApp1.accessToken()); IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(resultMiApp2.accessToken()); @@ -547,6 +565,8 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi @Test void azureArcManagedIdentity_InvalidAuthHeader() throws Exception { IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); HttpResponse response = new HttpResponse(); @@ -566,7 +586,6 @@ void azureArcManagedIdentity_InvalidAuthHeader() throws Exception { try { miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); } catch (Exception exception) { assert(exception.getCause() instanceof MsalServiceException); @@ -586,6 +605,8 @@ void azureArcManagedIdentity_InvalidAuthHeader() throws Exception { void azureArcManagedIdentityAuthheaderTest() throws Exception { Path path = Paths.get(this.getClass().getResource("/msi-azure-arc-secret.txt").toURI()); IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + ManagedIdentityClient.resetManagedIdentitySourceType(); DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); // Mock 401 response that returns www-authenticate header @@ -610,7 +631,6 @@ void azureArcManagedIdentityAuthheaderTest() throws Exception { IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) - .environmentVariables(environmentVariables) .build()).get(); assertNotNull(result.accessToken());