Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/msal-node/apiReview/msal-node.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ export { LogLevel }
export class ManagedIdentityApplication {
constructor(configuration?: ManagedIdentityConfiguration);
acquireToken(managedIdentityRequestParams: ManagedIdentityRequestParams): Promise<AuthenticationResult>;
getManagedIdentitySource(): ManagedIdentitySourceNames;
getManagedIdentitySource(): Promise<ManagedIdentitySourceNames>;
}

// @public (undocumented)
Expand Down
6 changes: 3 additions & 3 deletions lib/msal-node/src/client/ManagedIdentityApplication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ export class ManagedIdentityApplication {
*/
if (managedIdentityRequest.claims) {
const sourceName: ManagedIdentitySourceNames =
this.managedIdentityClient.getManagedIdentitySource();
await this.managedIdentityClient.getManagedIdentitySource();

/*
* Check if there is a cached token and if the Managed Identity source supports token revocation.
Expand Down Expand Up @@ -257,10 +257,10 @@ export class ManagedIdentityApplication {
* Determine the Managed Identity Source based on available environment variables. This API is consumed by Azure Identity SDK.
* @returns ManagedIdentitySourceNames - The Managed Identity source's name
*/
public getManagedIdentitySource(): ManagedIdentitySourceNames {
public async getManagedIdentitySource(): Promise<ManagedIdentitySourceNames> {
return (
ManagedIdentityClient.sourceName ||
this.managedIdentityClient.getManagedIdentitySource()
(await this.managedIdentityClient.getManagedIdentitySource())
);
}
}
17 changes: 10 additions & 7 deletions lib/msal-node/src/client/ManagedIdentityClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class ManagedIdentityClient {
): Promise<AuthenticationResult> {
if (!ManagedIdentityClient.identitySource) {
ManagedIdentityClient.identitySource =
this.selectManagedIdentitySource(
await this.selectManagedIdentitySource(
this.logger,
this.nodeStorage,
this.networkClient,
Expand Down Expand Up @@ -96,7 +96,7 @@ export class ManagedIdentityClient {
* This API is consumed by ManagedIdentityApplication's getManagedIdentitySource.
* @returns ManagedIdentitySourceNames - The Managed Identity source's name
*/
public getManagedIdentitySource(): ManagedIdentitySourceNames {
public async getManagedIdentitySource(): Promise<ManagedIdentitySourceNames> {
ManagedIdentityClient.sourceName =
this.allEnvironmentVariablesAreDefined(
ServiceFabric.getEnvironmentVariables()
Expand All @@ -118,7 +118,10 @@ export class ManagedIdentityClient {
AzureArc.getEnvironmentVariables()
)
? ManagedIdentitySourceNames.AZURE_ARC
: ImdsV2.isCredentialEndpointAvailable()
: (await ImdsV2.isCredentialEndpointAvailable(
this.logger,
this.networkClient
))
? ManagedIdentitySourceNames.IMDSV2
: ManagedIdentitySourceNames.DEFAULT_TO_IMDS;

Expand All @@ -129,14 +132,14 @@ export class ManagedIdentityClient {
* Tries to create a managed identity source for all sources
* @returns the managed identity Source
*/
private selectManagedIdentitySource(
private async selectManagedIdentitySource(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
disableInternalRetries: boolean,
managedIdentityId: ManagedIdentityId
): BaseManagedIdentitySource {
): Promise<BaseManagedIdentitySource> {
const source =
ServiceFabric.tryCreate(
logger,
Expand Down Expand Up @@ -176,13 +179,13 @@ export class ManagedIdentityClient {
disableInternalRetries,
managedIdentityId
) ||
ImdsV2.tryCreate(
(await ImdsV2.tryCreate(
logger,
nodeStorage,
networkClient,
cryptoProvider,
disableInternalRetries
) ||
)) ||
Imds.tryCreate(
logger,
nodeStorage,
Expand Down
8 changes: 5 additions & 3 deletions lib/msal-node/src/client/ManagedIdentitySources/Imds.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ export class Imds extends BaseManagedIdentitySource {
cryptoProvider: CryptoProvider,
disableInternalRetries: boolean
): Imds {
const validatedIdentityEndpoint: string =
this.getValidatedIdentityEndpoint(IMDS_TOKEN_PATH, logger);
const validatedIdentityEndpoint: string = this.getValidatedEndpoint(
IMDS_TOKEN_PATH,
logger
);

return new Imds(
logger,
Expand Down Expand Up @@ -134,7 +136,7 @@ export class Imds extends BaseManagedIdentitySource {
return request;
}

public static getValidatedIdentityEndpoint = (
public static getValidatedEndpoint = (
subPath: string,
logger: Logger
): string => {
Expand Down
94 changes: 78 additions & 16 deletions lib/msal-node/src/client/ManagedIdentitySources/ImdsV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
* Licensed under the MIT License.
*/

import { INetworkModule, Logger } from "@azure/msal-common/node";
import {
INetworkModule,
Logger,
NetworkResponse,
} from "@azure/msal-common/node";
// import { Agent } from "https";
import { ManagedIdentityId } from "../../config/ManagedIdentityId.js";
import { ManagedIdentityRequestParameters } from "../../config/ManagedIdentityRequestParameters.js";
Expand All @@ -18,20 +22,27 @@ import {
import { NodeStorage } from "../../cache/NodeStorage.js";
import { Imds, IMDS_API_VERSION } from "./Imds.js";
import { ShortLivedCredential } from "../../response/ShortLivedCredentialResponse.js";
import { HttpClientWithRetries } from "../../network/HttpClientWithRetries.js";
import { DefaultManagedIdentityRetryPolicy } from "../../retry/DefaultManagedIdentityRetryPolicy.js";

const CREDENTIAL_PATH: string =
export const CREDENTIAL_PATH: string =
"/metadata/identity/credential?cred-api-version=1.0";

export interface CredentialEndpointProbeResponse {
error: string;
error_description: string;
}

export class ImdsV2 extends BaseManagedIdentitySource {
private identityEndpoint: string;
private credentialEndpoint: string;

constructor(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
disableInternalRetries: boolean,
identityEndpoint: string
credentialEndpoint: string
) {
super(
logger,
Expand All @@ -41,36 +52,87 @@ export class ImdsV2 extends BaseManagedIdentitySource {
disableInternalRetries
);

this.identityEndpoint = identityEndpoint;
this.credentialEndpoint = credentialEndpoint;
}

public static tryCreate(
public static async tryCreate(
logger: Logger,
nodeStorage: NodeStorage,
networkClient: INetworkModule,
cryptoProvider: CryptoProvider,
disableInternalRetries: boolean
): ImdsV2 | null {
if (!this.isCredentialEndpointAvailable()) {
): Promise<ImdsV2 | null> {
const validatedCredentialEndpoint: string = Imds.getValidatedEndpoint(
CREDENTIAL_PATH,
logger
);

if (
!(await this.isCredentialEndpointAvailable(
logger,
networkClient,
validatedCredentialEndpoint
))
) {
return null;
}

const validatedIdentityEndpoint: string =
Imds.getValidatedIdentityEndpoint(CREDENTIAL_PATH, logger);

return new ImdsV2(
logger,
nodeStorage,
networkClient,
cryptoProvider,
disableInternalRetries,
validatedIdentityEndpoint
validatedCredentialEndpoint
);
}

public static isCredentialEndpointAvailable(): boolean {
// TODO: Probe credential endpoint. If it doesn't return 400, return null
return false;
public static async isCredentialEndpointAvailable(
logger: Logger,
networkClient: INetworkModule,
credentialEndpoint?: string // only passed in from tryCreate in this class
): Promise<boolean> {
const validatedCredentialEndpoint: string =
credentialEndpoint ||
Imds.getValidatedEndpoint(CREDENTIAL_PATH, logger);

const networkClientWithRetry: INetworkModule =
new HttpClientWithRetries(
networkClient,
/*
* TODO: create probe credential endpoint retry policy that extends DefaultManagedIdentityRetryPolicy,
* that only retries on 400 and 500
*/
new DefaultManagedIdentityRetryPolicy(),
logger
);

const response: NetworkResponse<CredentialEndpointProbeResponse> =
await networkClientWithRetry.sendPostRequestAsync<CredentialEndpointProbeResponse>(
validatedCredentialEndpoint,
{ body: "." }
);

if (response.status !== 400) {
return false;
}

/*
* Match "IMDS/" at start of "server" header string (`^IMDS\/`)
* Match the first three numbers with dots (`\d+.\d+.\d+.`)
* Capture the last number in a group (`(\d+)`)
* Ensure end of string (`$`)
*
* Example:
* [
* "IMDS/150.870.65.1556", // index 0: full match
* "1556" // index 1: captured group (\d+)
* ]
*/
const versionMatch = response.headers["server"]?.match(
/^IMDS\/\d+\.\d+\.\d+\.(\d+)$/
);
return Boolean(versionMatch && parseInt(versionMatch[1], 10) > 1324); // .match can return null, so Boolean() is needed
}

public createRequest(
Expand All @@ -80,7 +142,7 @@ export class ImdsV2 extends BaseManagedIdentitySource {
const imdsRequest: ManagedIdentityRequestParameters =
new ManagedIdentityRequestParameters(
HttpMethod.POST,
this.identityEndpoint
this.credentialEndpoint
);

imdsRequest.headers[ManagedIdentityHeaders.METADATA_HEADER_NAME] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()

const managedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(userAssignedClientIdConfig);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.APP_SERVICE
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.APP_SERVICE);

const networkManagedIdentityResult: AuthenticationResult =
await managedIdentityApplication.acquireToken(
Expand Down Expand Up @@ -100,9 +100,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()

const managedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(userAssignedResourceIdConfig);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.APP_SERVICE
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.APP_SERVICE);

const networkManagedIdentityResult: AuthenticationResult =
await managedIdentityApplication.acquireToken(
Expand Down Expand Up @@ -133,13 +133,13 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()

describe("System Assigned", () => {
let managedIdentityApplication: ManagedIdentityApplication;
beforeEach(() => {
beforeEach(async () => {
managedIdentityApplication = new ManagedIdentityApplication(
systemAssignedConfig
);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.APP_SERVICE
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.APP_SERVICE);
});

test("acquires a token", async () => {
Expand Down Expand Up @@ -193,9 +193,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", ()

const managedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(systemAssignedConfig);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.APP_SERVICE
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.APP_SERVICE);

let serverError: ServerError = new ServerError();
try {
Expand Down
26 changes: 13 additions & 13 deletions lib/msal-node/test/client/ManagedIdentitySources/AzureArc.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
// Azure Arc Managed Identities can only be system assigned
describe("System Assigned", () => {
let managedIdentityApplication: ManagedIdentityApplication;
beforeEach(() => {
beforeEach(async () => {
managedIdentityApplication = new ManagedIdentityApplication(
systemAssignedConfig
);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.AZURE_ARC
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.AZURE_ARC);
});

test("acquires a token", async () => {
Expand Down Expand Up @@ -135,7 +135,7 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
// and accessSyncSpy still returns an error
// (meaning either the himds file doesn't exists or its permissions don't allow it to be read)
expect(
managedIdentityApplication.getManagedIdentitySource()
await managedIdentityApplication.getManagedIdentitySource()
).not.toBe(ManagedIdentitySourceNames.AZURE_ARC);
// delete value cached from getManagedIdentitySource() directly above
delete ManagedIdentityClient["sourceName"];
Expand All @@ -146,9 +146,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =
return undefined;
});

expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.AZURE_ARC
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.AZURE_ARC);

// returns undefined when the himds file exists and its permissions allow it to be read
// otherwise, throws an error
Expand Down Expand Up @@ -253,20 +253,20 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () =

describe("Errors", () => {
let managedIdentityApplication: ManagedIdentityApplication;
beforeEach(() => {
beforeEach(async () => {
managedIdentityApplication = new ManagedIdentityApplication(
systemAssignedConfig
);
expect(managedIdentityApplication.getManagedIdentitySource()).toBe(
ManagedIdentitySourceNames.AZURE_ARC
);
expect(
await managedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.AZURE_ARC);
});

test("throws an error if a user assigned managed identity is used", async () => {
const userAssignedManagedIdentityApplication: ManagedIdentityApplication =
new ManagedIdentityApplication(userAssignedClientIdConfig);
expect(
userAssignedManagedIdentityApplication.getManagedIdentitySource()
await userAssignedManagedIdentityApplication.getManagedIdentitySource()
).toBe(ManagedIdentitySourceNames.AZURE_ARC);

await expect(
Expand Down
Loading