Skip to content

Commit

Permalink
ManagedIdentityCredential honors CancellationTokens (#47171)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Nov 22, 2024
1 parent 4d42d7b commit fc1a231
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 4 deletions.
6 changes: 4 additions & 2 deletions sdk/identity/Azure.Identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
### Breaking Changes

### Bugs Fixed
- Fixed an issue where setting `DefaultAzureCredentialOptions.TenantId` twice throws an `InvalidOperationException`. ([#47035](https://github.com/Azure/azure-sdk-for-net/issues/47035))
- Fixed an issue where some credentials in `DefaultAzureCredential` would not fall through to the next credential in the chain under certain exception conditions.

- Fixed an issue where setting `DefaultAzureCredentialOptions.TenantId` twice throws an `InvalidOperationException` ([#47035](https://github.com/Azure/azure-sdk-for-net/issues/47035))
- Fixed an issue where `ManagedIdentityCredential` does not honor the `CancellationToken` passed to `GetToken` and `GetTokenAsync`. ([#47156](https://github.com/Azure/azure-sdk-for-net/issues/47156))
- Fixed an issue where some credentials in `DefaultAzureCredential` would not fall through to the next credential in the chain under certain exception conditions.

### Other Changes

Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/Azure.Identity/src/MsalManagedIdentityClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public virtual async ValueTask<AuthenticationResult> AcquireTokenForManagedIdent
}
#pragma warning disable AZC0102 // Do not use GetAwaiter().GetResult().
return async ?
await builder.ExecuteAsync().ConfigureAwait(false) :
builder.ExecuteAsync().GetAwaiter().GetResult();
await builder.ExecuteAsync(cancellationToken).ConfigureAwait(false) :
builder.ExecuteAsync(cancellationToken).GetAwaiter().GetResult();
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult().
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
Expand Down Expand Up @@ -174,6 +175,64 @@ public void ManagedIdentityCredentialUsesDefaultTimeoutAndRetries()
CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts);
}

[Test]
public void ManagedIdentityCredentialRetryBehaviorIsOverriddenWithOptions()
{
int callCount = 0;
List<TimeSpan?> networkTimeouts = new();

var mockTransport = MockTransport.FromMessageCallback(msg =>
{
callCount++;
networkTimeouts.Add(msg.NetworkTimeout);
Assert.IsTrue(msg.Request.Headers.TryGetValue(ImdsManagedIdentitySource.metadataHeaderName, out _));
return CreateMockResponse(500, "Error").WithHeader("Content-Type", "application/json");
});

var options = new TokenCredentialOptions()
{
Transport = mockTransport,
RetryPolicy = new RetryPolicy(1, DelayStrategy.CreateFixedDelayStrategy(TimeSpan.Zero))
};
options.Retry.MaxDelay = TimeSpan.Zero;

var cred = new ManagedIdentityCredential(
"testCLientId", options);

Assert.ThrowsAsync<AuthenticationFailedException>(async () => await cred.GetTokenAsync(new(new[] { "test" })));

var expectedTimeouts = new TimeSpan?[] { null, null };
CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts);
}

[Test]
public void ManagedIdentityCredentialRespectsCancellationToken()
{
int callCount = 0;

var mockTransport = MockTransport.FromMessageCallback(msg =>
{
Task.Delay(1000).GetAwaiter().GetResult();
callCount++;
return CreateMockResponse(500, "Error").WithHeader("Content-Type", "application/json");
});

var options = new TokenCredentialOptions() { Transport = mockTransport };
options.Retry.MaxDelay = TimeSpan.FromSeconds(1);

var cred = new ManagedIdentityCredential(
"testCLientId", options);

var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.Zero);
var ex = Assert.CatchAsync(async () => await cred.GetTokenAsync(new(new[] { "test" }), cts.Token));
Assert.IsTrue(ex is TaskCanceledException || ex is OperationCanceledException, "Expected TaskCanceledException or OperationCanceledException but got " + ex.GetType().ToString());

// Default number of retries is 5, so we should just ensure we have less than that.
// Timing on some platforms makes this test somewhat non-deterministic, so we just ensure we have less than 2 calls.
Assert.Less(callCount, 2);
}

private MockResponse CreateMockResponse(int responseCode, string token)
{
var response = new MockResponse(responseCode);
Expand Down

0 comments on commit fc1a231

Please sign in to comment.