Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically Create Database if Not Present #49

Merged
merged 6 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/Identifier.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

namespace DurableTask.SqlServer
{
using System;
using System.Text;

static class Identifier
cgillum marked this conversation as resolved.
Show resolved Hide resolved
{
public static string Escape(string value)
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}

if (value == "")
{
throw new ArgumentException("Value cannot be empty.", nameof(value));
}

StringBuilder builder = new StringBuilder();

builder.Append('[');
foreach (char c in value)
{
if (c == ']')
{
builder.Append(']');
cgillum marked this conversation as resolved.
Show resolved Hide resolved
}

builder.Append(c);
}
builder.Append(']');

return builder.ToString();
}
}
}
19 changes: 18 additions & 1 deletion src/DurableTask.SqlServer/LogHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void AcquiredAppLock(int statusCode, Stopwatch latencyStopwatch)
var logEvent = new LogEvents.AcquiredAppLockEvent(
statusCode,
latencyStopwatch.ElapsedMilliseconds);

this.WriteLog(logEvent);
}

Expand Down Expand Up @@ -112,6 +112,23 @@ public void PurgedInstances(string userId, int purgedInstanceCount, Stopwatch la
this.WriteLog(logEvent);
}

public void CommandCompleted(string commandText, Stopwatch latencyStopwatch, int retryCount, string? instanceId)
{
var logEvent = new LogEvents.CommandCompletedEvent(
commandText,
latencyStopwatch.ElapsedMilliseconds,
retryCount,
instanceId);

this.WriteLog(logEvent);
}

public void CreatedDatabase(string databaseName)
{
var logEvent = new LogEvents.CreatedDatabaseEvent(databaseName);
this.WriteLog(logEvent);
}

void WriteLog(ILogEvent logEvent)
{
// LogDurableEvent is an extension method defined in DurableTask.Core
Expand Down
34 changes: 34 additions & 0 deletions src/DurableTask.SqlServer/Logging/DefaultEventSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,39 @@ internal void PurgedInstances(
AppName,
ExtensionVersion);
}

[Event(EventIds.CommandCompleted, Level = EventLevel.Verbose)]
public void CommandCompleted(
string? InstanceId,
string CommandText,
long LatencyMs,
int RetryCount,
string AppName,
string ExtensionVersion)
{
// TODO: Switch to WriteEventCore for better performance
this.WriteEvent(
EventIds.CommandCompleted,
InstanceId ?? string.Empty,
CommandText,
LatencyMs,
RetryCount,
AppName,
ExtensionVersion);
}

[Event(EventIds.CreatedDatabase, Level = EventLevel.Informational)]
internal void CreatedDatabase(
string DatabaseName,
string AppName,
string ExtensionVersion)
{
// TODO: Use WriteEventCore for better performance
this.WriteEvent(
EventIds.CreatedDatabase,
DatabaseName,
AppName,
ExtensionVersion);
}
}
}
2 changes: 2 additions & 0 deletions src/DurableTask.SqlServer/Logging/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ static class EventIds
public const int TransientDatabaseFailure = 308;
public const int ReplicaCountChangeRecommended = 309;
public const int PurgedInstances = 310;
public const int CommandCompleted = 311;
public const int CreatedDatabase = 312;
}
}
66 changes: 66 additions & 0 deletions src/DurableTask.SqlServer/Logging/LogEvents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,5 +415,71 @@ void IEventSourceEvent.WriteEventSource() =>
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CommandCompletedEvent : StructuredLogEvent, IEventSourceEvent
{
public CommandCompletedEvent(string commandText, long latencyMs, int retryCount, string? instanceId)
{
this.CommandText = commandText;
this.LatencyMs = latencyMs;
this.RetryCount = retryCount;
this.InstanceId = instanceId;
}

[StructuredLogField]
public string CommandText { get; }

[StructuredLogField]
public long LatencyMs { get; }

[StructuredLogField]
public int RetryCount { get; }

[StructuredLogField]
public string? InstanceId { get; }

public override LogLevel Level => LogLevel.Debug;

public override EventId EventId => new EventId(
EventIds.CommandCompleted,
nameof(EventIds.CommandCompleted));

protected override string CreateLogMessage() =>
string.IsNullOrEmpty(this.InstanceId) ?
$"Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms" :
$"{this.InstanceId}: Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CommandCompleted(
this.InstanceId,
this.CommandText,
this.LatencyMs,
this.RetryCount,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CreatedDatabaseEvent : StructuredLogEvent, IEventSourceEvent
{
public CreatedDatabaseEvent(string databaseName) =>
this.DatabaseName = databaseName;

[StructuredLogField]
public string DatabaseName { get; }

public override EventId EventId => new EventId(
EventIds.CreatedDatabase,
nameof(EventIds.CreatedDatabase));

public override LogLevel Level => LogLevel.Information;

protected override string CreateLogMessage() => $"Created database '{this.DatabaseName}'.";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CreatedDatabase(
this.DatabaseName,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}
}
}
53 changes: 52 additions & 1 deletion src/DurableTask.SqlServer/SqlDbManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,28 @@ public SqlDbManager(SqlOrchestrationServiceSettings settings, LogHelper traceHel

public async Task CreateOrUpgradeSchemaAsync(bool recreateIfExists)
{
// Note that we may not be able to connect to the DB, let alone obtain the lock,
// if the database does not exist yet. So we obtain a connection to the 'master' database for now.
SqlConnection connection = this.settings.CreateConnection("master");
await connection.OpenAsync();

if (!await this.DoesDatabaseExistAsync(this.settings.DatabaseName, connection))
{
if (!this.settings.CreateDatabaseIfNotExists)
{
throw new InvalidOperationException($"Database '{this.settings.DatabaseName}' does not exist.");
}

await this.CreateDatabaseAsync(this.settings.DatabaseName, connection);
}

// Prevent other create or delete operations from executing at the same time.
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync();
#if NETSTANDARD2_0
connection.ChangeDatabase(this.settings.DatabaseName);
#else
await connection.ChangeDatabaseAsync(this.settings.DatabaseName);
#endif
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync(connection);

var currentSchemaVersion = new SemanticVersion(0, 0, 0);
if (recreateIfExists)
Expand Down Expand Up @@ -136,6 +156,11 @@ async Task<DatabaseLock> AcquireDatabaseLockAsync()
SqlConnection connection = this.settings.CreateConnection();
await connection.OpenAsync();

return await this.AcquireDatabaseLockAsync(connection);
}

async Task<DatabaseLock> AcquireDatabaseLockAsync(SqlConnection connection)
{
// It's possible that more than one worker may attempt to execute this creation logic at the same
// time. To avoid update conflicts, we use an app lock + a transaction to ensure only a single worker
// can perform an upgrade at a time. All other workers will wait for the first one to complete.
Expand Down Expand Up @@ -171,6 +196,32 @@ async Task<DatabaseLock> AcquireDatabaseLockAsync()
return new DatabaseLock(connection, lockTransaction);
}

async Task<bool> DoesDatabaseExistAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = "SELECT 1 FROM sys.databases WHERE name = @databaseName";
command.Parameters.AddWithValue("@databaseName", databaseName);

bool exists = (int?)await SqlUtils.ExecuteScalarAsync(command, this.traceHelper) == 1;
return exists;
}

async Task CreateDatabaseAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = $"CREATE DATABASE {Identifier.Escape(databaseName)} COLLATE Latin1_General_100_BIN2_UTF8";
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use Identifier.Escape(string) to handle possibly non-standard names as well as protect against any malicious queries while we're setting up.


try
{
await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper);
this.traceHelper.CreatedDatabase(databaseName);
cgillum marked this conversation as resolved.
Show resolved Hide resolved
}
catch (SqlException e) when (e.Number == 1801 /* Database already exists */)
{
// Ignore
}
}

async Task ExecuteSqlScriptAsync(string scriptName, DatabaseLock dbLock)
{
// We don't actually use the lock here, but want to make sure the caller is holding it.
Expand Down
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
ApplicationName = this.TaskHubName,
};

if (string.IsNullOrEmpty(builder.InitialCatalog))
{
throw new ArgumentException("Database or Initial Catalog must be specified in the connection string.", nameof(connectionString));
}

this.DatabaseName = builder.InitialCatalog;
this.TaskHubConnectionString = builder.ToString();
}

Expand Down Expand Up @@ -79,6 +85,15 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonProperty("maxActiveOrchestrations")]
public int MaxActiveOrchestrations { get; set; } = Environment.ProcessorCount;

/// <summary>
/// Gets or sets a flag indicating whether the database should be automatically created if it does not exist.
/// </summary>
/// <remarks>
/// If <see langword="true"/>, the user requires the permission <c>CREATE DATABASE</c>.
/// </remarks>
[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }
cgillum marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Gets a SQL connection string associated with the configured task hub.
/// </summary>
Expand All @@ -91,6 +106,31 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonIgnore]
public ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

/// <summary>
/// Gets or sets the name of the database that contains the instance store.
/// </summary>
/// <remarks>
/// This value is derived from the value of the <c>"Initial Catalog"</c> or <c>"Database"</c>
/// attribute in the <see cref="TaskHubConnectionString"/>.
/// </remarks>
[JsonIgnore]
public string DatabaseName { get; set; }

internal SqlConnection CreateConnection() => new SqlConnection(this.TaskHubConnectionString);

internal SqlConnection CreateConnection(string databaseName)
{
if (databaseName == this.DatabaseName)
cgillum marked this conversation as resolved.
Show resolved Hide resolved
{
return this.CreateConnection();
}

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(this.TaskHubConnectionString)
{
InitialCatalog = databaseName
};

return new SqlConnection(builder.ToString());
}
}
}
25 changes: 23 additions & 2 deletions src/DurableTask.SqlServer/SqlUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static HistoryEvent GetHistoryEvent(this DbDataReader reader, bool isOrch
int eventId = GetTaskId(reader);

HistoryEvent historyEvent;
switch(eventType)
switch (eventType)
{
case EventType.ContinueAsNew:
historyEvent = new ContinueAsNewEvent(eventId, GetPayloadText(reader));
Expand Down Expand Up @@ -396,6 +396,19 @@ public static Task<int> ExecuteNonQueryAsync(
cmd => cmd.ExecuteNonQueryAsync(cancellationToken));
}

public static Task<object> ExecuteScalarAsync(
DbCommand command,
LogHelper traceHelper,
string? instanceId = null,
CancellationToken cancellationToken = default)
{
return ExecuteSprocAndTraceAsync(
command,
traceHelper,
instanceId,
cmd => cmd.ExecuteScalarAsync(cancellationToken));
}

static async Task<T> ExecuteSprocAndTraceAsync<T>(
DbCommand command,
LogHelper traceHelper,
Expand All @@ -410,7 +423,15 @@ static async Task<T> ExecuteSprocAndTraceAsync<T>(
finally
{
context.LatencyStopwatch.Stop();
traceHelper.SprocCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
switch (command.CommandType)
{
case CommandType.StoredProcedure:
traceHelper.SprocCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
break;
default:
traceHelper.CommandCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
break;
}
}
}

Expand Down
Loading