Skip to content

Commit

Permalink
Initial MSI credential implementation (#6464)
Browse files Browse the repository at this point in the history
* initial MSI credential implementation

* removing uneeded TokenResponse class

* adding mock msi credential tests

* adding mock transport msi credential test

* updating System.Text.Json package reference accross all projects

* removing comments from errantly commented asserts

* upgrading System.Threading.Tasks.Extensions to match version in Azure.Core

* fixing datetime / datetime offset discrepency in SecretsTests.cs
  • Loading branch information
schaabs authored Jun 6, 2019
1 parent e7dc3b4 commit 8e5dca3
Show file tree
Hide file tree
Showing 20 changed files with 469 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="1.0.0-preview6.19259.10" />
<PackageReference Include="System.Text.Json" Version="4.6.0-preview6.19226.8" />
<PackageReference Include="System.Text.Json" Version="4.6.0-preview6.19259.10" />
</ItemGroup>

<!-- Import the Azure.Core project -->
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/Azure.Core/tests/Azure.Core.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<PackageReference Include="NUnit3TestAdapter" Version="3.10.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.8.0" />
<PackageReference Include="System.Memory" Version="4.5.2" />
<PackageReference Include="System.Text.Json" Version="4.6.0-preview6.19226.8" />
<PackageReference Include="System.Text.Json" Version="4.6.0-preview6.19259.10" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\src\Azure.Core.csproj">
Expand Down
8 changes: 6 additions & 2 deletions sdk/identity/Azure.Identity/src/Azure.Identity.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@

<ItemGroup>
<PackageReference Include="System.Memory" Version="4.5.2" />
<PackageReference Include="System.Text.Json" Version="4.6.0-preview6.19226.8" />
<PackageReference Include="System.Threading.Tasks.Extensions" Version="4.5.1" />
<PackageReference Include="System.Text.Json" Version="4.6.0-preview6.19259.10" />
<PackageReference Include="System.Threading.Tasks.Extensions" Version="4.5.2" />
</ItemGroup>

<ItemGroup>
<Compile Include="$(AzureCoreSharedSources)ArrayBufferWriter.cs" />
</ItemGroup>

<!-- Import the Azure.Base project -->
<Import Project="$(MSBuildThisFileDirectory)..\..\..\core\Azure.Core\src\Azure.Core.props" />
</Project>
3 changes: 2 additions & 1 deletion sdk/identity/Azure.Identity/src/AzureCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public override string GetToken(string[] scopes, CancellationToken cancellationT

protected abstract AccessToken GetTokenCore(string[] scopes, CancellationToken cancellationToken);

internal IdentityClient Client => _client;
internal IdentityClient Client { get => _client; set => _client = value; }


public static TokenCredential Default { get; set; }

Expand Down
66 changes: 64 additions & 2 deletions sdk/identity/Azure.Identity/src/IdentityClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ internal class IdentityClient
{
private readonly IdentityClientOptions _options;
private readonly HttpPipeline _pipeline;
private readonly Uri ImdsEndptoint = new Uri("http://169.254.169.254/metadata/identity/oauth2/token");
private readonly string MsiApiVersion = "2018-02-01";

public IdentityClient(IdentityClientOptions options = null)
{
Expand All @@ -32,7 +34,7 @@ public IdentityClient(IdentityClientOptions options = null)
BufferResponsePolicy.Singleton);
}

public async Task<AccessToken> AuthenticateAsync(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default)
public virtual async Task<AccessToken> AuthenticateAsync(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default)
{
using (Request request = CreateClientSecretAuthRequest(tenantId, clientId, clientSecret, scopes))
{
Expand All @@ -49,7 +51,7 @@ public async Task<AccessToken> AuthenticateAsync(string tenantId, string clientI
}
}

public AccessToken Authenticate(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default)
public virtual AccessToken Authenticate(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default)
{
using (Request request = CreateClientSecretAuthRequest(tenantId, clientId, clientSecret, scopes))
{
Expand All @@ -66,6 +68,66 @@ public AccessToken Authenticate(string tenantId, string clientId, string clientS
}
}

public virtual async Task<AccessToken> AuthenticateManagedIdentityAsync(string[] scopes, string clientId = null, CancellationToken cancellationToken = default)
{
using (Request request = CreateManagedIdentityAuthRequest(scopes, clientId))
{
var response = await _pipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false);

if (response.Status == 200 || response.Status == 201)
{
var result = await DeserializeAsync(response.ContentStream, cancellationToken).ConfigureAwait(false);

return new Response<AccessToken>(response, result);
}

throw response.CreateRequestFailedException();
}
}

public virtual AccessToken AuthenticateManagedIdentity(string[] scopes, string clientId = null, CancellationToken cancellationToken = default)
{
using (Request request = CreateManagedIdentityAuthRequest(scopes, clientId))
{
var response = _pipeline.SendRequest(request, cancellationToken);

if (response.Status == 200 || response.Status == 201)
{
var result = Deserialize(response.ContentStream);

return new Response<AccessToken>(response, result);
}

throw response.CreateRequestFailedException();
}
}

private Request CreateManagedIdentityAuthRequest(string[] scopes, string clientId = null)
{
// covert the scopes to a resource string
string resource = ScopeUtilities.ScopesToResource(scopes);

Request request = _pipeline.CreateRequest();

request.Method = HttpPipelineMethod.Get;

request.Headers.Add("Metadata", "true");

// TODO: support MSI for hosted services
request.UriBuilder.Uri = ImdsEndptoint;

request.UriBuilder.AppendQuery("api-version", MsiApiVersion);

request.UriBuilder.AppendQuery("resource", Uri.EscapeDataString(resource));

if (!string.IsNullOrEmpty(clientId))
{
request.UriBuilder.AppendQuery("client_id", Uri.EscapeDataString(clientId));
}

return request;
}

private Request CreateClientSecretAuthRequest(string tenantId, string clientId, string clientSecret, string[] scopes)
{
Request request = _pipeline.CreateRequest();
Expand Down
33 changes: 33 additions & 0 deletions sdk/identity/Azure.Identity/src/ManagedIdentityCredential.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Azure.Identity
{
public class ManagedIdentityCredential : AzureCredential
{
private string _clientId;

public ManagedIdentityCredential(string clientId = null, IdentityClientOptions options = null)
: base(options)
{
_clientId = clientId;
}

protected override async Task<AccessToken> GetTokenCoreAsync(string[] scopes, CancellationToken cancellationToken = default)
{
return await this.Client.AuthenticateManagedIdentityAsync(scopes, _clientId, cancellationToken).ConfigureAwait(false);
}

protected override AccessToken GetTokenCore(string[] scopes, CancellationToken cancellationToken = default)
{
return this.Client.AuthenticateManagedIdentity(scopes, _clientId, cancellationToken);
}
}
}
1 change: 1 addition & 0 deletions sdk/identity/Azure.Identity/src/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
using System.Runtime.CompilerServices;

[assembly:AzureSdkClientLibrary("identity")]
[assembly: InternalsVisibleTo("Azure.Identity.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100d15ddcb29688295338af4b7686603fe614abd555e09efba8fb88ee09e1f7b1ccaeed2e8f823fa9eef3fdd60217fc012ea67d2479751a0b8c087a4185541b851bd8b16f8d91b840e51b1cb0ba6fe647997e57429265e85ef62d565db50a69ae1647d54d7bd855e4db3d8a91510e5bcbd0edfbbecaa20a7bd9ae74593daa7b11b4")]
37 changes: 37 additions & 0 deletions sdk/identity/Azure.Identity/src/ScopeUtilities.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Azure.Identity
{
internal static class ScopeUtilities
{
private const string DefaultSuffix = "/.defualt";


public static string ScopesToResource(string[] scopes)
{
if (scopes == null) throw new ArgumentNullException(nameof(scopes));

if (scopes.Length != 1) throw new ArgumentException("To convert to a resource string the specified array must be exactly length 1", nameof(scopes));

if (!scopes[0].EndsWith(DefaultSuffix))
{
return scopes[0];
}

return scopes[0].Remove(scopes[0].LastIndexOf(DefaultSuffix));
}

public static string[] ResourceToScopes(string resource)
{
return new string[] { resource + "/.default" };
}

}
}
9 changes: 7 additions & 2 deletions sdk/identity/Azure.Identity/tests/Azure.Identity.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="nunit" Version="3.11.0" />
<PackageReference Include="NUnit3TestAdapter" Version="3.12.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
<PackageReference Include="xunit" Version="2.4.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.0" />
<PackageReference Include="Moq" Version="4.10.1" />
<PackageReference Include="BenchmarkDotNet" Version="0.11.5" />
</ItemGroup>

<Import Project="..\..\..\core\Azure.Core\tests\TestFramework.props" />


<ItemGroup>
<ProjectReference Include="..\src\Azure.Identity.csproj" />
</ItemGroup>
Expand Down
10 changes: 5 additions & 5 deletions sdk/identity/Azure.Identity/tests/AzureCredentialTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using NUnit.Framework;

namespace Azure.Identity.Tests
{
Expand Down Expand Up @@ -38,7 +38,7 @@ protected override async Task<AccessToken> GetTokenCoreAsync(string[] scopes, Ca
}
}

[Fact]
[Test]
public async Task RefreshLogicDefaultAsync()
{
TimeSpan refreshBuffer = new IdentityClientOptions().RefreshBuffer;
Expand All @@ -56,9 +56,9 @@ public async Task RefreshLogicDefaultAsync()
await cred.GetTokenAsync(new string[] { "mockscope" });
}

Assert.Equal(2, refreshCred1.AuthCount);
Assert.Equal(2, refreshCred2.AuthCount);
Assert.Equal(1, notRefreshCred1.AuthCount);
Assert.AreEqual(2, refreshCred1.AuthCount);
Assert.AreEqual(2, refreshCred2.AuthCount);
Assert.AreEqual(1, notRefreshCred1.AuthCount);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using System.Collections.Generic;
using System.Reflection;
using System.Text;
using Xunit;
using NUnit.Framework;

namespace Azure.Identity.Tests
{
Expand All @@ -15,16 +15,11 @@ public static TokenCredential _credential(this EnvironmentCredential provider)
}
}

[CollectionDefinition("EnvironmentTests", DisableParallelization = true)]
public class EnvironmentTestsCollection
{
}


[Collection("EnvironmentTests")]
public class EnvironmentCredentialProviderTests
{
[Fact]
[NonParallelizable]
[Test]
public void CredentialConstruction()
{
string clientIdBackup = Environment.GetEnvironmentVariable("AZURE_CLIENT_ID");
Expand All @@ -45,11 +40,11 @@ public void CredentialConstruction()

Assert.NotNull(cred);

Assert.Equal("mockclientid", cred.ClientId);
Assert.AreEqual("mockclientid", cred.ClientId);

Assert.Equal("mocktenantid", cred.TenantId);
Assert.AreEqual("mocktenantid", cred.TenantId);

Assert.Equal("mockclientsecret", cred.ClientSecret);
Assert.AreEqual("mockclientsecret", cred.ClientSecret);
}
finally
{
Expand Down
Loading

0 comments on commit 8e5dca3

Please sign in to comment.