Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure Managed Identity support - first version #25

Merged
merged 1 commit into from
May 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/DurableTask.SqlServer.AzureFunctions/ManagedIdentityOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

namespace DurableTask.SqlServer.AzureFunctions
{
using System;
using Newtonsoft.Json;

class ManagedIdentityOptions
{
[JsonProperty("authorityHost")]
public Uri? AuthorityHost { get; set; }

[JsonProperty("tenantId")]
public string? TenantId { get; set; }

[JsonProperty("useAzureManagedIdentity")]
public bool UseAzureManagedIdentity { get; set; }
}
}
12 changes: 12 additions & 0 deletions src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SqlDurabilityOptions
[JsonProperty("taskEventBatchSize")]
public int TaskEventBatchSize { get; set; } = 10;

public ManagedIdentityOptions? ManagedIdentityOptions { get; set; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add [JsonProperty("managedIdentityOptions")] to this new property.

Copy link
Member

@cgillum cgillum May 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the Azure Functions configuration system doesn't like our use of nested options here. We'll need to flatten it in order for it to be recognized. My suggestion is that we add the three properties directly to the SqlDurabilityOptions class and rename them such that they all have the same prefix. For example, azureManagedIdentityEnabled, azureManagedIdentityAuthorityHost, and azureManagedIdentityTenantId. The ManagedIdentitySettings class, however, can remain as-is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let's do it. Thanks for letting me know.


internal ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(
Expand Down Expand Up @@ -69,6 +71,16 @@ internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(
settings.MaxActiveOrchestrations = extensionOptions.MaxConcurrentOrchestratorFunctions.Value;
}

if (this.ManagedIdentityOptions != null)
{
settings.ManagedIdentitySettings = new ManagedIdentitySettings
{
UseAzureManagedIdentity = this.ManagedIdentityOptions.UseAzureManagedIdentity,
AuthorityHost = this.ManagedIdentityOptions.AuthorityHost,
TenantId = this.ManagedIdentityOptions.TenantId
};
}

return settings;
}
}
Expand Down
1 change: 1 addition & 0 deletions src/DurableTask.SqlServer/DurableTask.SqlServer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.4.0" />
<PackageReference Include="Microsoft.Azure.DurableTask.Core" Version="2.5.1" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="3.1.*" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.*" />
Expand Down
45 changes: 45 additions & 0 deletions src/DurableTask.SqlServer/ManagedIdentitySettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

using System;
using Newtonsoft.Json;

namespace DurableTask.SqlServer
{
/// <summary>
/// Configuration options for Managed Identity.
/// </summary>
public class ManagedIdentitySettings
{
public const string Resource = "https://database.windows.net/";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my testing, I noticed that I needed to add ".default" to the end of this URL before I could successfully fetch a token from Azure AD. I think this might be a newer requirement in the more modern versions of Azure.Identity.

Suggested change
public const string Resource = "https://database.windows.net/";
public const string Resource = "https://database.windows.net/.default";


/// <summary>
/// Initializes a new instance of the <see cref="ManagedIdentitySettings"/> class.
/// </summary>
/// <param name="authorityHost">The host of the Azure Active Directory authority.</param>
/// <param name="tenantId">The tenant id of the user to authenticate.</param>
public ManagedIdentitySettings(Uri? authorityHost = null, string? tenantId = null)
{
this.AuthorityHost = authorityHost;
this.TenantId = tenantId;
}

/// <summary>
/// The host of the Azure Active Directory authority. The default is https://login.microsoftonline.com/.
/// </summary>
[JsonProperty("authorityHost")]
public Uri? AuthorityHost { get; set; }

/// <summary>
/// The tenant id of the user to authenticate.
/// </summary>
[JsonProperty("tenantId")]
public string? TenantId { get; set; }

/// <summary>
/// Use Azure Managed Identity to connect to SQL Server
/// </summary>
[JsonProperty("useAzureManagedIdentity")]
public bool UseAzureManagedIdentity { get; set; }
}
}
47 changes: 47 additions & 0 deletions src/DurableTask.SqlServer/SqlConnectionFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

using System.Threading.Tasks;

using Azure.Core;
using Azure.Identity;

using Microsoft.Data.SqlClient;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for stylistic consistency, it would be great if you could move these using statements inside the namespace block.


namespace DurableTask.SqlServer
{
internal class SqlConnectionFactory
{
readonly string connectionString;
readonly ManagedIdentitySettings? managedIdentitySettings;

public SqlConnectionFactory(string connectionString, ManagedIdentitySettings? managedIdentitySettings = null)
{
this.connectionString = connectionString;
this.managedIdentitySettings = managedIdentitySettings;
}

public async Task<SqlConnection> CreateConnection()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this to CreateConnectionAsync()? It's not a big deal, but we try to follow this naming convention for methods that return tasks (which is the case now that we may fetch an Azure AD token).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I guess I forgot about this naming pattern :)

{
var connection = new SqlConnection(this.connectionString);
if (this.managedIdentitySettings != null && this.managedIdentitySettings.UseAzureManagedIdentity)
{
var azureCredentialOptions = new DefaultAzureCredentialOptions();
if (this.managedIdentitySettings.AuthorityHost != null)
{
azureCredentialOptions.AuthorityHost = this.managedIdentitySettings.AuthorityHost;
}
if (!string.IsNullOrEmpty(this.managedIdentitySettings.TenantId))
{
azureCredentialOptions.InteractiveBrowserTenantId = this.managedIdentitySettings.TenantId;
}
var azureCredential = new DefaultAzureCredential(azureCredentialOptions);
var requestContext = new TokenRequestContext(new string[] { ManagedIdentitySettings.Resource });
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to cache this requestContext object?

var accessToken = await azureCredential.GetTokenAsync(requestContext);
connection.AccessToken = accessToken.Token;
}

return connection;
}
}
}
10 changes: 5 additions & 5 deletions src/DurableTask.SqlServer/SqlDbManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ namespace DurableTask.SqlServer

class SqlDbManager
{
readonly SqlOrchestrationServiceSettings settings;
readonly SqlConnectionFactory connectionFactory;
readonly LogHelper traceHelper;

public SqlDbManager(SqlOrchestrationServiceSettings settings, LogHelper traceHelper)
public SqlDbManager(SqlConnectionFactory connectionFactory, LogHelper traceHelper)
{
this.settings = settings ?? throw new ArgumentNullException(nameof(settings));
this.connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory));
this.traceHelper = traceHelper ?? throw new ArgumentNullException(nameof(traceHelper));
}

Expand Down Expand Up @@ -133,7 +133,7 @@ public async Task DeleteSchemaAsync()

async Task<DatabaseLock> AcquireDatabaseLockAsync()
{
SqlConnection connection = this.settings.CreateConnection();
SqlConnection connection = await this.connectionFactory.CreateConnection();
await connection.OpenAsync();

// It's possible that more than one worker may attempt to execute this creation logic at the same
Expand Down Expand Up @@ -187,7 +187,7 @@ async Task ExecuteSqlScriptAsync(string scriptName, DatabaseLock dbLock)
string schemaCommands = await GetScriptTextAsync(scriptName);

// Reference: https://stackoverflow.com/questions/650098/how-to-execute-an-sql-script-file-using-c-sharp
using SqlConnection scriptRunnerConnection = this.settings.CreateConnection();
using SqlConnection scriptRunnerConnection = await this.connectionFactory.CreateConnection();
var serverConnection = new ServerConnection(scriptRunnerConnection);

Stopwatch latencyStopwatch = Stopwatch.StartNew();
Expand Down
7 changes: 5 additions & 2 deletions src/DurableTask.SqlServer/SqlOrchestrationService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class SqlOrchestrationService : OrchestrationServiceBase
minimumInterval: TimeSpan.FromMilliseconds(50),
maximumInterval: TimeSpan.FromSeconds(3)); // TODO: Configurable

readonly SqlConnectionFactory connectionFactory;
readonly SqlOrchestrationServiceSettings settings;
readonly LogHelper traceHelper;
readonly SqlDbManager dbManager;
Expand All @@ -40,7 +41,9 @@ public SqlOrchestrationService(SqlOrchestrationServiceSettings? settings)
{
this.settings = ValidateSettings(settings) ?? throw new ArgumentNullException(nameof(settings));
this.traceHelper = new LogHelper(this.settings.LoggerFactory.CreateLogger("DurableTask.SqlServer"));
this.dbManager = new SqlDbManager(this.settings, this.traceHelper);
this.connectionFactory = new SqlConnectionFactory(
this.settings.TaskHubConnectionString, this.settings.ManagedIdentitySettings);
this.dbManager = new SqlDbManager(this.connectionFactory, this.traceHelper);
this.lockedByValue = $"{this.settings.AppName}|{Process.GetCurrentProcess().Id}";
this.userId = new SqlConnectionStringBuilder(this.settings.TaskHubConnectionString).UserID ?? string.Empty;
}
Expand Down Expand Up @@ -79,7 +82,7 @@ async Task<SqlConnection> GetAndOpenConnectionAsync(CancellationToken cancelToke
cancelToken = this.ShutdownToken;
}

SqlConnection connection = this.settings.CreateConnection();
SqlConnection connection = await this.connectionFactory.CreateConnection();
await connection.OpenAsync(cancelToken);
return connection;
}
Expand Down
9 changes: 5 additions & 4 deletions src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonIgnore]
public ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

internal SqlConnection CreateConnection()
{
return new SqlConnection(this.TaskHubConnectionString);
}
/// <summary>
/// Gets or sets managed identity settings used to connect to a database.
/// </summary>
[JsonProperty("managedIdentitySettings")]
public ManagedIdentitySettings? ManagedIdentitySettings { get; set; }
}
}