Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert ExecuteNonQueryAsync to use async context object #1692

Merged
merged 2 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,68 @@ namespace Microsoft.Data.SqlClient
// CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching
// DO NOT use this class' state after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object

internal abstract class AAsyncCallContext<TOwner, TTask> : IDisposable
internal abstract class AAsyncCallContext<TOwner, TTask, TDisposable> : AAsyncBaseCallContext<TOwner,TTask>
where TOwner : class
where TDisposable : IDisposable
{
protected TOwner _owner;
protected TaskCompletionSource<TTask> _source;
protected IDisposable _disposable;
protected TDisposable _disposable;

protected AAsyncCallContext()
{
}

protected AAsyncCallContext(TOwner owner, TaskCompletionSource<TTask> source, IDisposable disposable = null)
protected AAsyncCallContext(TOwner owner, TaskCompletionSource<TTask> source, TDisposable disposable = default)
{
Set(owner, source, disposable);
}

protected void Set(TOwner owner, TaskCompletionSource<TTask> source, IDisposable disposable = null)
protected void Set(TOwner owner, TaskCompletionSource<TTask> source, TDisposable disposable = default)
{
base.Set(owner, source);
_disposable = disposable;
}

protected override void DisposeCore()
{
TDisposable copyDisposable = _disposable;
_disposable = default;
copyDisposable?.Dispose();
}
}

internal abstract class AAsyncBaseCallContext<TOwner, TTask>
{
protected TOwner _owner;
protected TaskCompletionSource<TTask> _source;
protected bool _isDisposed;

protected AAsyncBaseCallContext()
{
}

protected void Set(TOwner owner, TaskCompletionSource<TTask> source)
{
_owner = owner ?? throw new ArgumentNullException(nameof(owner));
_source = source ?? throw new ArgumentNullException(nameof(source));
_disposable = disposable;
_isDisposed = false;
}

protected void ClearCore()
{
_source = null;
_owner = default;
IDisposable copyDisposable = _disposable;
_disposable = null;
copyDisposable?.Dispose();
try
{
DisposeCore();
}
finally
{
_isDisposed = true;
}
}

protected abstract void DisposeCore();

/// <summary>
/// override this method to cleanup instance data before ClearCore is called which will blank the base data
/// </summary>
Expand All @@ -65,16 +95,19 @@ protected virtual void AfterCleared(TOwner owner)

public void Dispose()
{
TOwner owner = _owner;
try
{
Clear();
}
finally
if (!_isDisposed)
{
ClearCore();
TOwner owner = _owner;
try
{
Clear();
}
finally
{
ClearCore();
}
AfterCleared(owner);
}
AfterCleared(owner);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ public sealed partial class SqlCommand : DbCommand, ICloneable
private static readonly Func<SqlCommand, CommandBehavior, AsyncCallback, object, int, bool, bool, IAsyncResult> s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback;
private static readonly Func<SqlCommand, CommandBehavior, AsyncCallback, object, int, bool, bool, IAsyncResult> s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback;

internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext<SqlCommand, SqlDataReader>
internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext<SqlCommand, SqlDataReader, CancellationTokenRegistration>
{
public Guid OperationID;
public CommandBehavior CommandBehavior;

public SqlCommand Command => _owner;
public TaskCompletionSource<SqlDataReader> TaskCompletionSource => _source;

public void Set(SqlCommand command, TaskCompletionSource<SqlDataReader> source, IDisposable disposable, CommandBehavior behavior, Guid operationID)
public void Set(SqlCommand command, TaskCompletionSource<SqlDataReader> source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID)
{
base.Set(command, source, disposable);
CommandBehavior = behavior;
Expand All @@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner)
}
}

internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext<SqlCommand, int, CancellationTokenRegistration>
{
public Guid OperationID;

public SqlCommand Command => _owner;

public TaskCompletionSource<int> TaskCompletionSource => _source;

public void Set(SqlCommand command, TaskCompletionSource<int> source, CancellationTokenRegistration disposable, Guid operationID)
{
base.Set(command, source, disposable);
OperationID = operationID;
}

protected override void Clear()
{
OperationID = default;
}

protected override void AfterCleared(SqlCommand owner)
{

}
}

private CommandType _commandType;
private int? _commandTimeout;
private UpdateRowSource _updatedRowSource = UpdateRowSource.Both;
Expand Down Expand Up @@ -2540,23 +2565,36 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
}

Task<int> returnedTask = source.Task;
returnedTask = RegisterForConnectionCloseNotification(returnedTask);

ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext();
context.Set(this, source, registration, operationId);
try
{
returnedTask = RegisterForConnectionCloseNotification(returnedTask);

Task<int>.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null)
.ContinueWith((Task<int> task) =>
Task<int>.Factory.FromAsync(
static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject),
static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result),
state: context
).ContinueWith(
static (Task<int> task, object state) =>
{
registration.Dispose();
ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state;

Guid operationId = context.OperationID;
SqlCommand command = context.Command;
TaskCompletionSource<int> source = context.TaskCompletionSource;

context.Dispose();
context = null;

if (task.IsFaulted)
{
Exception e = task.Exception.InnerException;
s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
s_diagnosticListener.WriteCommandError(operationId, command, command._transaction, e);
source.SetException(e);
}
else
{
s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction);
if (task.IsCanceled)
{
source.SetCanceled();
Expand All @@ -2565,15 +2603,18 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
{
source.SetResult(task.Result);
}
s_diagnosticListener.WriteCommandAfter(operationId, command, command._transaction);
}
},
TaskScheduler.Default
},
state: context,
scheduler: TaskScheduler.Default
);
}
catch (Exception e)
{
s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
source.SetException(e);
context.Dispose();
}

return returnedTask;
Expand Down Expand Up @@ -2648,11 +2689,11 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
}

Task<SqlDataReader> returnedTask = source.Task;
ExecuteReaderAsyncCallContext context = null;
try
{
returnedTask = RegisterForConnectionCloseNotification(returnedTask);

ExecuteReaderAsyncCallContext context = null;
if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
{
context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null);
Expand Down Expand Up @@ -2680,6 +2721,7 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
}

source.SetException(e);
context.Dispose();
}

return returnedTask;
Expand Down
Loading