Skip to content

Commit

Permalink
Fix issue where stand-alone ManagedIdentityCredential does not consid…
Browse files Browse the repository at this point in the history
…er WorkloadIdentity (#46693)
  • Loading branch information
christothes authored Oct 17, 2024
1 parent e545aec commit 02da266
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 67 deletions.
1 change: 1 addition & 0 deletions sdk/identity/Azure.Identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Breaking Changes

### Bugs Fixed
- Fixed an issue that prevented ManagedIdentityCredential from attempting to detect if Workload Identity is enabled in the current environment. [#46653](https://github.com/Azure/azure-sdk-for-net/issues/46653)

### Other Changes

Expand Down
11 changes: 11 additions & 0 deletions sdk/identity/Azure.Identity/src/AzureIdentityEventSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ internal sealed class AzureIdentityEventSource : AzureEventSource, IIdentityLogg
internal const int UnableToParseAccountDetailsFromTokenEvent = 20;
private const int UserAssignedManagedIdentityNotSupportedEvent = 21;
private const int ServiceFabricManagedIdentityRuntimeConfigurationNotSupportedEvent = 22;
private const int ManagedIdentitySourceAttemptedEvent = 25;
internal const string TenantIdDiscoveredAndNotUsedEventMessage = "A token was request for a different tenant than was configured on the credential, but the configured value was used since multi tenant authentication has been disabled. Configured TenantId: {0}, Requested TenantId {1}";
internal const string TenantIdDiscoveredAndUsedEventMessage = "A token was requested for a different tenant than was configured on the credential, and the requested tenant id was used to authenticate. Configured TenantId: {0}, Requested TenantId {1}";
internal const string AuthenticatedAccountDetailsMessage = "Client ID: {0}. Tenant ID: {1}. User Principal Name: {2} Object ID: {3}";
internal const string Unavailable = "<not available>";
internal const string UnableToParseAccountDetailsFromTokenMessage = "Unable to parse account details from the Access Token";
internal const string UserAssignedManagedIdentityNotSupportedMessage = "User assigned managed identities are not supported in the {0} environment.";
internal const string ServiceFabricManagedIdentityRuntimeConfigurationNotSupportedMessage = "Service Fabric user assigned managed identity ClientId or ResourceId is not configurable at runtime.";
internal const string ManagedIdentitySourceAttemptedMessage = "ManagedIdentitySource {0} was attempted. IsSelected={1}.";

private AzureIdentityEventSource() : base(EventSourceName) { }

Expand Down Expand Up @@ -401,5 +403,14 @@ public void ServiceFabricManagedIdentityRuntimeConfigurationNotSupported()
WriteEvent(ServiceFabricManagedIdentityRuntimeConfigurationNotSupportedEvent);
}
}

[Event(ManagedIdentitySourceAttemptedEvent, Level = EventLevel.Informational, Message = ManagedIdentitySourceAttemptedMessage)]
public void ManagedIdentitySourceAttempted(string source, bool isSelected)
{
if (IsEnabled(EventLevel.Informational, EventKeywords.All))
{
WriteEvent(ManagedIdentitySourceAttemptedEvent, source, isSelected);
}
}
}
}
70 changes: 29 additions & 41 deletions sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ internal class ManagedIdentityClient
internal Lazy<ManagedIdentitySource> _identitySource;
private MsalConfidentialClient _msalConfidentialClient;
private MsalManagedIdentityClient _msalManagedIdentityClient;
private bool _enableLegacyMI;
private bool _isChainedCredential;
private ManagedIdentityClientOptions _options;

protected ManagedIdentityClient()
{
Expand All @@ -39,12 +39,12 @@ public ManagedIdentityClient(CredentialPipeline pipeline, ResourceIdentifier res

public ManagedIdentityClient(ManagedIdentityClientOptions options)
{
_options = options.Clone();
ManagedIdentityId = options.ManagedIdentityId;
Pipeline = options.Pipeline;
_enableLegacyMI = options.EnableManagedIdentityLegacyBehavior;
_isChainedCredential = options.Options?.IsChainedCredential ?? false;
_msalManagedIdentityClient = new MsalManagedIdentityClient(options);
_identitySource = new Lazy<ManagedIdentitySource>(() => SelectManagedIdentitySource(options, _enableLegacyMI, _msalManagedIdentityClient));
_identitySource = new Lazy<ManagedIdentitySource>(() => SelectManagedIdentitySource(options, _msalManagedIdentityClient));
_msalConfidentialClient = new MsalConfidentialClient(
Pipeline,
"MANAGED-IDENTITY-RESOURCE-TENENT",
Expand All @@ -60,31 +60,33 @@ public ManagedIdentityClient(ManagedIdentityClientOptions options)
public async ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
{
AuthenticationResult result;
if (_enableLegacyMI)

var availableSource = ManagedIdentityApplication.GetManagedIdentitySource();

// If the source is DefaultToImds and the credential is chained, we should probe the IMDS endpoint first.
if (availableSource == MSAL.ManagedIdentitySource.DefaultToImds && _isChainedCredential)
{
return await AuthenticateCoreAsync(async, context, cancellationToken).ConfigureAwait(false);
}

// ServiceFabric does not support specifying user-assigned managed identity by client ID or resource ID. The managed identity selected is based on the resource configuration.
if (availableSource == MSAL.ManagedIdentitySource.ServiceFabric && (ManagedIdentityId?._idType != ManagedIdentityIdType.SystemAssigned))
{
result = await _msalConfidentialClient.AcquireTokenForClientAsync(context.Scopes, context.TenantId, context.Claims, context.IsCaeEnabled, async, cancellationToken).ConfigureAwait(false);
throw new AuthenticationFailedException(Constants.MiSeviceFabricNoUserAssignedIdentityMessage);
}
else

// First try the TokenExchangeManagedIdentitySource, if it is not available, fall back to MSAL directly.
var tokenExchangeManagedIdentitySource = TokenExchangeManagedIdentitySource.TryCreate(_options);
if (default != tokenExchangeManagedIdentitySource)
{
var availableSource = ManagedIdentityApplication.GetManagedIdentitySource();

// If the source is DefaultToImds and the credential is chained, we should probe the IMDS endpoint first.
if (availableSource == MSAL.ManagedIdentitySource.DefaultToImds && _isChainedCredential)
{
return await AuthenticateCoreAsync(async, context, cancellationToken).ConfigureAwait(false);
}

// ServiceFabric does not support specifying user-assigned managed identity by client ID or resource ID. The managed identity selected is based on the resource configuration.
if (availableSource == MSAL.ManagedIdentitySource.ServiceFabric && (ManagedIdentityId?._idType != ManagedIdentityIdType.SystemAssigned))
{
throw new AuthenticationFailedException(Constants.MiSeviceFabricNoUserAssignedIdentityMessage);
}

// The default case is to use the MSAL implementation, which does no probing of the IMDS endpoint.
result = async ?
await _msalManagedIdentityClient.AcquireTokenForManagedIdentityAsync(context, cancellationToken).ConfigureAwait(false) :
_msalManagedIdentityClient.AcquireTokenForManagedIdentity(context, cancellationToken);
return await tokenExchangeManagedIdentitySource.AuthenticateAsync(async, context, cancellationToken).ConfigureAwait(false);
}

// The default case is to use the MSAL implementation, which does no probing of the IMDS endpoint.
result = async ?
await _msalManagedIdentityClient.AcquireTokenForManagedIdentityAsync(context, cancellationToken).ConfigureAwait(false) :
_msalManagedIdentityClient.AcquireTokenForManagedIdentity(context, cancellationToken);

return result.ToAccessToken();
}

Expand Down Expand Up @@ -115,24 +117,10 @@ private async Task<AppTokenProviderResult> AppTokenProviderImpl(AppTokenProvider
};
}

private static ManagedIdentitySource SelectManagedIdentitySource(ManagedIdentityClientOptions options, bool _enableLegacyMI = true, MsalManagedIdentityClient client = null)
private static ManagedIdentitySource SelectManagedIdentitySource(ManagedIdentityClientOptions options, MsalManagedIdentityClient client = null)
{
if (_enableLegacyMI)
{
return
ServiceFabricManagedIdentitySource.TryCreate(options) ??
AppServiceV2019ManagedIdentitySource.TryCreate(options) ??
AppServiceV2017ManagedIdentitySource.TryCreate(options) ??
CloudShellManagedIdentitySource.TryCreate(options) ??
AzureArcManagedIdentitySource.TryCreate(options) ??
TokenExchangeManagedIdentitySource.TryCreate(options) ??
new ImdsManagedIdentitySource(options);
}
else
{
return TokenExchangeManagedIdentitySource.TryCreate(options) ??
new ImdsManagedIdentityProbeSource(options, client);
}
return TokenExchangeManagedIdentitySource.TryCreate(options) ??
new ImdsManagedIdentityProbeSource(options, client);
}
}
}
33 changes: 30 additions & 3 deletions sdk/identity/Azure.Identity/src/ManagedIdentityClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,36 @@ internal class ManagedIdentityClientOptions

public bool ExcludeTokenExchangeManagedIdentitySource { get; set; }

// TODO: revert before GA
public bool EnableManagedIdentityLegacyBehavior { get; set; } = Environment.GetEnvironmentVariable("AZURE_IDENTITY_ENABLE_LEGACY_IMDS_BEHAVIOR") != null;

public bool IsForceRefreshEnabled { get; set; }

public ManagedIdentityClientOptions Clone()
{
var cloned = new ManagedIdentityClientOptions
{
ManagedIdentityId = ManagedIdentityId,
PreserveTransport = PreserveTransport,
InitialImdsConnectionTimeout = InitialImdsConnectionTimeout,
Pipeline = Pipeline,
ExcludeTokenExchangeManagedIdentitySource = ExcludeTokenExchangeManagedIdentitySource,
IsForceRefreshEnabled = IsForceRefreshEnabled,
};

if (Options != null)
{
if (Options is DefaultAzureCredentialOptions dac)
{
cloned.Options = dac.Clone<DefaultAzureCredentialOptions>();
}
else if (Options is ManagedIdentityCredentialOptions mic)
{
cloned.Options = mic.Clone<ManagedIdentityCredentialOptions>();
}
else
{
cloned.Options = Options.Clone<TokenCredentialOptions>();
}
}
return cloned;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
Expand Down Expand Up @@ -33,9 +32,11 @@ public static ManagedIdentitySource TryCreate(ManagedIdentityClientOptions optio

if (options.ExcludeTokenExchangeManagedIdentitySource || string.IsNullOrEmpty(tokenFilePath) || string.IsNullOrEmpty(tenantId) || string.IsNullOrEmpty(clientId))
{
AzureIdentityEventSource.Singleton.ManagedIdentitySourceAttempted("TokenExchangeManagedIdentitySource", false);
return default;
}

AzureIdentityEventSource.Singleton.ManagedIdentitySourceAttempted("TokenExchangeManagedIdentitySource", true);
return new TokenExchangeManagedIdentitySource(options.Pipeline, tenantId, clientId, tokenFilePath);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,6 @@ public ManagedIdentityCredentialTests(bool isAsync) : base(isAsync)

private const string ExpectedToken = "mock-msi-access-token";

[NonParallelizable]
[Test]
public async Task VerifyTokenCaching()
{
using var environment = new TestEnvVar(new() { { "AZURE_IDENTITY_ENABLE_LEGACY_IMDS_BEHAVIOR", "true" } });
int callCount = 0;

var mockClient = new MockManagedIdentityClient(CredentialPipeline.GetInstance(null))
{
TokenFactory = () => { callCount++; return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddHours(24)); }
};

var cred = InstrumentClient(new ManagedIdentityCredential(mockClient));

for (int i = 0; i < 5; i++)
{
await cred.GetTokenAsync(new TokenRequestContext(MockScopes.Default));
}

Assert.AreEqual(1, callCount);
}

[Test]
public async Task VerifyExpiringTokenRefresh()
{
Expand Down Expand Up @@ -1035,6 +1013,11 @@ public void VerifyArcIdentitySourceFilePathValidation_FilePathInvalid()
[Test]
public async Task VerifyTokenExchangeMsiRequestMockAsync()
{
List<string> messages = new();
using AzureEventSourceListener listener = new AzureEventSourceListener(
(_, message) => messages.Add(message),
EventLevel.Informational);

var tenantId = "mock-tenant-id";
var clientId = "mock-client-id";
var authorityHostUrl = "https://mock.authority.com";
Expand Down Expand Up @@ -1081,6 +1064,7 @@ public async Task VerifyTokenExchangeMsiRequestMockAsync()
AccessToken actualToken = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default));

Assert.AreEqual(ExpectedToken, actualToken.Token);
Assert.That(messages, Does.Contain(string.Format(AzureIdentityEventSource.ManagedIdentitySourceAttemptedMessage, "TokenExchangeManagedIdentitySource", true)));
}

private static IEnumerable<TestCaseData> ResourceAndClientIds()
Expand Down

0 comments on commit 02da266

Please sign in to comment.