Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically Create Database if Not Present #49

Merged
merged 6 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions global.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
wsugarman marked this conversation as resolved.
Show resolved Hide resolved
"sdk": {
"version": "3.1.403"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class SqlDurabilityOptions
[JsonProperty("taskEventBatchSize")]
public int TaskEventBatchSize { get; set; } = 10;

[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }

internal ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(
Expand Down Expand Up @@ -54,9 +57,10 @@ internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(

var settings = new SqlOrchestrationServiceSettings(connectionString, this.TaskHubName)
{
CreateDatabaseIfNotExists = this.CreateDatabaseIfNotExists,
LoggerFactory = this.LoggerFactory,
WorkItemLockTimeout = this.TaskEventLockTimeout,
WorkItemBatchSize = this.TaskEventBatchSize,
WorkItemLockTimeout = this.TaskEventLockTimeout,
};

if (extensionOptions.MaxConcurrentActivityFunctions.HasValue)
Expand Down
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/Identifier.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

namespace DurableTask.SqlServer
{
using System;
using System.Text;

static class Identifier
cgillum marked this conversation as resolved.
Show resolved Hide resolved
{
public static string Escape(string value)
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}

if (value == "")
{
throw new ArgumentException("Value cannot be empty.", nameof(value));
}

StringBuilder builder = new StringBuilder();

builder.Append('[');
foreach (char c in value)
{
if (c == ']')
{
builder.Append(']');
cgillum marked this conversation as resolved.
Show resolved Hide resolved
}

builder.Append(c);
}
builder.Append(']');

return builder.ToString();
}
}
}
19 changes: 18 additions & 1 deletion src/DurableTask.SqlServer/LogHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void AcquiredAppLock(int statusCode, Stopwatch latencyStopwatch)
var logEvent = new LogEvents.AcquiredAppLockEvent(
statusCode,
latencyStopwatch.ElapsedMilliseconds);

this.WriteLog(logEvent);
}

Expand Down Expand Up @@ -112,6 +112,23 @@ public void PurgedInstances(string userId, int purgedInstanceCount, Stopwatch la
this.WriteLog(logEvent);
}

public void CommandCompleted(string commandText, Stopwatch latencyStopwatch, int retryCount, string? instanceId)
{
var logEvent = new LogEvents.CommandCompletedEvent(
commandText,
latencyStopwatch.ElapsedMilliseconds,
retryCount,
instanceId);

this.WriteLog(logEvent);
}

public void CreatedDatabase(string databaseName)
{
var logEvent = new LogEvents.CreatedDatabaseEvent(databaseName);
this.WriteLog(logEvent);
}

void WriteLog(ILogEvent logEvent)
{
// LogDurableEvent is an extension method defined in DurableTask.Core
Expand Down
34 changes: 34 additions & 0 deletions src/DurableTask.SqlServer/Logging/DefaultEventSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,39 @@ internal void PurgedInstances(
AppName,
ExtensionVersion);
}

[Event(EventIds.CommandCompleted, Level = EventLevel.Verbose)]
public void CommandCompleted(
string? InstanceId,
string CommandText,
long LatencyMs,
int RetryCount,
string AppName,
string ExtensionVersion)
{
// TODO: Switch to WriteEventCore for better performance
this.WriteEvent(
EventIds.CommandCompleted,
InstanceId ?? string.Empty,
CommandText,
LatencyMs,
RetryCount,
AppName,
ExtensionVersion);
}

[Event(EventIds.CreatedDatabase, Level = EventLevel.Informational)]
internal void CreatedDatabase(
string DatabaseName,
string AppName,
string ExtensionVersion)
{
// TODO: Use WriteEventCore for better performance
this.WriteEvent(
EventIds.CreatedDatabase,
DatabaseName,
AppName,
ExtensionVersion);
}
}
}
2 changes: 2 additions & 0 deletions src/DurableTask.SqlServer/Logging/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ static class EventIds
public const int TransientDatabaseFailure = 308;
public const int ReplicaCountChangeRecommended = 309;
public const int PurgedInstances = 310;
public const int CommandCompleted = 311;
public const int CreatedDatabase = 312;
}
}
66 changes: 66 additions & 0 deletions src/DurableTask.SqlServer/Logging/LogEvents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,5 +415,71 @@ void IEventSourceEvent.WriteEventSource() =>
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CommandCompletedEvent : StructuredLogEvent, IEventSourceEvent
{
public CommandCompletedEvent(string commandText, long latencyMs, int retryCount, string? instanceId)
{
this.CommandText = commandText;
this.LatencyMs = latencyMs;
this.RetryCount = retryCount;
this.InstanceId = instanceId;
}

[StructuredLogField]
public string CommandText { get; }

[StructuredLogField]
public long LatencyMs { get; }

[StructuredLogField]
public int RetryCount { get; }

[StructuredLogField]
public string? InstanceId { get; }

public override LogLevel Level => LogLevel.Debug;

public override EventId EventId => new EventId(
EventIds.CommandCompleted,
nameof(EventIds.CommandCompleted));

protected override string CreateLogMessage() =>
string.IsNullOrEmpty(this.InstanceId) ?
$"Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms" :
$"{this.InstanceId}: Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CommandCompleted(
this.InstanceId,
this.CommandText,
this.LatencyMs,
this.RetryCount,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CreatedDatabaseEvent : StructuredLogEvent, IEventSourceEvent
{
public CreatedDatabaseEvent(string databaseName) =>
this.DatabaseName = databaseName;

[StructuredLogField]
public string DatabaseName { get; }

public override EventId EventId => new EventId(
EventIds.CreatedDatabase,
nameof(EventIds.CreatedDatabase));

public override LogLevel Level => LogLevel.Information;

protected override string CreateLogMessage() => $"Created database '{this.DatabaseName}'.";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CreatedDatabase(
this.DatabaseName,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}
}
}
50 changes: 48 additions & 2 deletions src/DurableTask.SqlServer/SqlDbManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public SqlDbManager(SqlOrchestrationServiceSettings settings, LogHelper traceHel
public async Task CreateOrUpgradeSchemaAsync(bool recreateIfExists)
{
// Prevent other create or delete operations from executing at the same time.
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync();
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync(this.settings.CreateDatabaseIfNotExists);

var currentSchemaVersion = new SemanticVersion(0, 0, 0);
if (recreateIfExists)
Expand Down Expand Up @@ -131,8 +131,13 @@ public async Task DeleteSchemaAsync()

Task DropSchemaAsync(DatabaseLock dbLock) => this.ExecuteSqlScriptAsync("drop-schema.sql", dbLock);

async Task<DatabaseLock> AcquireDatabaseLockAsync()
async Task<DatabaseLock> AcquireDatabaseLockAsync(bool createDatabaseIfNotExists = false)
{
if (createDatabaseIfNotExists)
{
await this.EnsureDatabaseExistsAsync();
}

SqlConnection connection = this.settings.CreateConnection();
await connection.OpenAsync();

Expand Down Expand Up @@ -171,6 +176,47 @@ async Task<DatabaseLock> AcquireDatabaseLockAsync()
return new DatabaseLock(connection, lockTransaction);
}

async Task EnsureDatabaseExistsAsync()
{
// Note that we may not be able to connect to the DB, let alone obtain the lock,
// if the database does not exist yet. So we obtain a connection to the 'master' database for now.
using SqlConnection connection = this.settings.CreateConnection("master");
await connection.OpenAsync();

if (!await this.DoesDatabaseExistAsync(this.settings.DatabaseName, connection))
{
await this.CreateDatabaseAsync(this.settings.DatabaseName, connection);
}
}

async Task<bool> DoesDatabaseExistAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = "SELECT 1 FROM sys.databases WHERE name = @databaseName";
command.Parameters.AddWithValue("@databaseName", databaseName);

bool exists = (int?)await SqlUtils.ExecuteScalarAsync(command, this.traceHelper) == 1;
return exists;
}

async Task<bool> CreateDatabaseAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = $"CREATE DATABASE {Identifier.Escape(databaseName)} COLLATE Latin1_General_100_BIN2_UTF8";
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use Identifier.Escape(string) to handle possibly non-standard names as well as protect against any malicious queries while we're setting up.


try
{
await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper);
this.traceHelper.CreatedDatabase(databaseName);
cgillum marked this conversation as resolved.
Show resolved Hide resolved
return true;
}
catch (SqlException e) when (e.Number == 1801 /* Database already exists */)
{
// Ignore
return false;
}
}

async Task ExecuteSqlScriptAsync(string scriptName, DatabaseLock dbLock)
{
// We don't actually use the lock here, but want to make sure the caller is holding it.
Expand Down
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
ApplicationName = this.TaskHubName,
};

if (string.IsNullOrEmpty(builder.InitialCatalog))
{
throw new ArgumentException("Database or Initial Catalog must be specified in the connection string.", nameof(connectionString));
}

this.DatabaseName = builder.InitialCatalog;
this.TaskHubConnectionString = builder.ToString();
}

Expand Down Expand Up @@ -79,6 +85,15 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonProperty("maxActiveOrchestrations")]
public int MaxActiveOrchestrations { get; set; } = Environment.ProcessorCount;

/// <summary>
/// Gets or sets a flag indicating whether the database should be automatically created if it does not exist.
/// </summary>
/// <remarks>
/// If <see langword="true"/>, the user requires the permission <c>CREATE DATABASE</c>.
/// </remarks>
[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }
cgillum marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Gets a SQL connection string associated with the configured task hub.
/// </summary>
Expand All @@ -91,6 +106,31 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonIgnore]
public ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

/// <summary>
/// Gets or sets the name of the database that contains the instance store.
/// </summary>
/// <remarks>
/// This value is derived from the value of the <c>"Initial Catalog"</c> or <c>"Database"</c>
/// attribute in the <see cref="TaskHubConnectionString"/>.
/// </remarks>
[JsonIgnore]
public string DatabaseName { get; set; }

internal SqlConnection CreateConnection() => new SqlConnection(this.TaskHubConnectionString);

internal SqlConnection CreateConnection(string databaseName)
{
if (databaseName == this.DatabaseName)
cgillum marked this conversation as resolved.
Show resolved Hide resolved
{
return this.CreateConnection();
}

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(this.TaskHubConnectionString)
{
InitialCatalog = databaseName
};

return new SqlConnection(builder.ToString());
}
}
}
Loading