diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 168584e9d4..774ee07582 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -80,7 +80,8 @@ public sealed class SqlCommand : DbCommand, ICloneable /// Internal flag for testing purposes that forces all queries to internally end async calls. /// private static bool _forceInternalEndQuery = false; -#endif +#endif + internal static readonly Action s_cancelIgnoreFailure = CancelIgnoreFailureCallback; // devnote: Prepare // Against 7.0 Server (Sphinx) a prepare/unprepare requires an extra roundtrip to the server. @@ -914,6 +915,12 @@ protected override DbParameterCollection DbParameterCollection return Parameters; } } + + internal static void CancelIgnoreFailureCallback(object state) + { + SqlCommand command = (SqlCommand)state; + command.CancelIgnoreFailure(); + } internal void CancelIgnoreFailure() { @@ -2972,7 +2979,7 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(CancelIgnoreFailure); + registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } Task returnedTask = source.Task; @@ -3069,7 +3076,7 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(CancelIgnoreFailure); + registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } Task returnedTask = source.Task; @@ -3215,7 +3222,7 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(CancelIgnoreFailure); + registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } Task returnedTask = source.Task; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 19697a5002..9605e27035 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -102,6 +102,9 @@ internal class SharedState private SqlSequentialStream _currentStream; private SqlSequentialTextReader _currentTextReader; + private IsDBNullAsyncCallContext _cachedIsDBNullContext; + private ReadAsyncCallContext _cachedReadAsyncContext; + internal SqlDataReader(SqlCommand command, CommandBehavior behavior) { SqlConnection.VerifyExecutePermission(); @@ -2168,45 +2171,45 @@ internal bool TryGetBytesInternalSequential(int i, byte[] buffer, int index, int { tdsReliabilitySection.Start(); #endif //DEBUG - if ((_sharedState._columnDataBytesRemaining == 0) || (length == 0)) - { - // No data left or nothing requested, return 0 - bytesRead = 0; - return true; - } - else + if ((_sharedState._columnDataBytesRemaining == 0) || (length == 0)) + { + // No data left or nothing requested, return 0 + bytesRead = 0; + return true; + } + else + { + // if plp columns, do partial reads. Don't read the entire value in one shot. + if (_metaData[i].metaType.IsPlp) { - // if plp columns, do partial reads. Don't read the entire value in one shot. - if (_metaData[i].metaType.IsPlp) + // Read in data + bool result = _stateObj.TryReadPlpBytes(ref buffer, index, length, out bytesRead); + _columnDataBytesRead += bytesRead; + if (!result) { - // Read in data - bool result = _stateObj.TryReadPlpBytes(ref buffer, index, length, out bytesRead); - _columnDataBytesRead += bytesRead; - if (!result) - { - return false; - } - - // Query for number of bytes left - ulong left; - if (!_parser.TryPlpBytesLeft(_stateObj, out left)) - { - _sharedState._columnDataBytesRemaining = -1; - return false; - } - _sharedState._columnDataBytesRemaining = (long)left; - return true; + return false; } - else + + // Query for number of bytes left + ulong left; + if (!_parser.TryPlpBytesLeft(_stateObj, out left)) { - // Read data (not exceeding the total amount of data available) - int bytesToRead = (int)Math.Min((long)length, _sharedState._columnDataBytesRemaining); - bool result = _stateObj.TryReadByteArray(buffer, index, bytesToRead, out bytesRead); - _columnDataBytesRead += bytesRead; - _sharedState._columnDataBytesRemaining -= bytesRead; - return result; + _sharedState._columnDataBytesRemaining = -1; + return false; } + _sharedState._columnDataBytesRemaining = (long)left; + return true; + } + else + { + // Read data (not exceeding the total amount of data available) + int bytesToRead = (int)Math.Min((long)length, _sharedState._columnDataBytesRemaining); + bool result = _stateObj.TryReadByteArray(buffer, index, bytesToRead, out bytesRead); + _columnDataBytesRead += bytesRead; + _sharedState._columnDataBytesRemaining -= bytesRead; + return result; } + } #if DEBUG } finally @@ -4155,7 +4158,7 @@ private bool TryReadColumnHeader(int i) { tdsReliabilitySection.Start(); #endif //DEBUG - return TryReadColumnInternal(i, readHeaderOnly: true); + return TryReadColumnInternal(i, readHeaderOnly: true); #if DEBUG } finally @@ -4821,7 +4824,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(_command.CancelIgnoreFailure); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); @@ -4839,37 +4842,33 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - PrepareAsyncInvocation(useSnapshot: true); - - Func> moreFunc = null; - - moreFunc = (t) => - { - if (t != null) - { - SqlClientEventSource.Log.TryTraceEvent(" attempt retry {0}", ObjectID); - PrepareForAsyncContinuation(); - } - - bool more; - if (TryNextResult(out more)) - { - // completed - return more ? ADP.TrueTask : ADP.FalseTask; - } + return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration)); + } + } - return ContinueRetryable(moreFunc); - }; + private static Task NextResultAsyncExecute(Task task, object state) + { + HasNextResultAsyncCallContext context = (HasNextResultAsyncCallContext)state; + if (task != null) + { + SqlClientEventSource.Log.TryTraceEvent(" attempt retry {0}", context._reader.ObjectID); + context._reader.PrepareForAsyncContinuation(); + } - return InvokeRetryable(moreFunc, source, registration); + if (context._reader.TryNextResult(out bool more)) + { + // completed + return more ? ADP.TrueTask : ADP.FalseTask; } + + return context._reader.ExecuteAsyncCall(context); } // NOTE: This will return null if it completed sequentially // If this returns null, then you can use bytesRead to see how many bytes were read - otherwise bytesRead should be ignored - internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, int timeout, CancellationToken cancellationToken, out int bytesRead) + internal Task GetBytesAsync(int columnIndex, byte[] buffer, int index, int length, int timeout, CancellationToken cancellationToken, out int bytesRead) { - AssertReaderState(requireData: true, permitAsync: true, columnIndex: i, enforceSequentialAccess: true); + AssertReaderState(requireData: true, permitAsync: true, columnIndex: columnIndex, enforceSequentialAccess: true); Debug.Assert(IsCommandBehavior(CommandBehavior.SequentialAccess)); bytesRead = 0; @@ -4895,6 +4894,16 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in } } + var context = new GetBytesAsyncCallContext(this) + { + columnIndex = columnIndex, + buffer = buffer, + index = index, + length = length, + timeout = timeout, + cancellationToken = cancellationToken, + }; + // Check if we need to skip columns Debug.Assert(_sharedState._nextColumnDataToRead <= _lastColumnWithDataChunkRead, "Non sequential access"); if ((_sharedState._nextColumnHeaderToRead <= _lastColumnWithDataChunkRead) || (_sharedState._nextColumnDataToRead < _lastColumnWithDataChunkRead)) @@ -4907,10 +4916,6 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in return source.Task; } - PrepareAsyncInvocation(useSnapshot: true); - - Func> moreFunc = null; - // Timeout CancellationToken timeoutToken = CancellationToken.None; CancellationTokenSource timeoutCancellationSource = null; @@ -4921,66 +4926,25 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in timeoutToken = timeoutCancellationSource.Token; } - moreFunc = (t) => - { - if (t != null) - { - SqlClientEventSource.Log.TryTraceEvent(" attempt retry {0}", ObjectID); - PrepareForAsyncContinuation(); - } - - // Prepare for stateObj timeout - SetTimeout(_defaultTimeoutMilliseconds); - - if (TryReadColumnHeader(i)) - { - // Only once we have read upto where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state) + context._disposable = timeoutCancellationSource; + context.timeoutToken = timeoutToken; + context._source = source; - if (cancellationToken.IsCancellationRequested) - { - // User requested cancellation - return ADP.CreatedTaskWithCancellation(); - } - else if (timeoutToken.IsCancellationRequested) - { - // Timeout - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); - } - else - { - // Upto the correct column - continue to read - SwitchToAsyncWithoutSnapshot(); - int totalBytesRead; - var readTask = GetBytesAsyncReadDataStage(i, buffer, index, length, timeout, true, cancellationToken, timeoutToken, out totalBytesRead); - if (readTask == null) - { - // Completed synchronously - return Task.FromResult(totalBytesRead); - } - else - { - return readTask; - } - } - } - else - { - return ContinueRetryable(moreFunc); - } - }; + PrepareAsyncInvocation(useSnapshot: true); - return InvokeRetryable(moreFunc, source, timeoutCancellationSource); + return InvokeAsyncCall(context); } else { // We're already at the correct column, just read the data + context.mode = GetBytesAsyncCallContext.OperationMode.Read; // Switch to async PrepareAsyncInvocation(useSnapshot: false); try { - return GetBytesAsyncReadDataStage(i, buffer, index, length, timeout, false, cancellationToken, CancellationToken.None, out bytesRead); + return GetBytesAsyncReadDataStage(context, false, out bytesRead); } catch { @@ -4990,17 +4954,126 @@ internal Task GetBytesAsync(int i, byte[] buffer, int index, int length, in } } - private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, int length, int timeout, bool isContinuation, CancellationToken cancellationToken, CancellationToken timeoutToken, out int bytesRead) + private static Task GetBytesAsyncSeekExecute(Task task, object state) + { + GetBytesAsyncCallContext context = (GetBytesAsyncCallContext)state; + SqlDataReader reader = context._reader; + + Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Seek, "context.mode must be Seek to check if seeking can resume"); + + if (task != null) + { + reader.PrepareForAsyncContinuation(); + } + // Prepare for stateObj timeout + reader.SetTimeout(reader._defaultTimeoutMilliseconds); + + if (reader.TryReadColumnHeader(context.columnIndex)) + { + // Only once we have read upto where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state) + + if (context.cancellationToken.IsCancellationRequested) + { + // User requested cancellation + return Task.FromCanceled(context.cancellationToken); + } + else if (context.timeoutToken.IsCancellationRequested) + { + // Timeout + return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); + } + else + { + // Upto the correct column - continue to read + context.mode = GetBytesAsyncCallContext.OperationMode.Read; + reader.SwitchToAsyncWithoutSnapshot(); + int totalBytesRead; + var readTask = reader.GetBytesAsyncReadDataStage(context, true, out totalBytesRead); + if (readTask == null) + { + // Completed synchronously + return Task.FromResult(totalBytesRead); + } + else + { + return readTask; + } + } + } + else + { + return reader.ExecuteAsyncCall(context); + } + } + + private static Task GetBytesAsyncReadExecute(Task task, object state) + { + var context = (GetBytesAsyncCallContext)state; + SqlDataReader reader = context._reader; + + Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.mode must be Read to check if read can resume"); + + reader.PrepareForAsyncContinuation(); + + if (context.cancellationToken.IsCancellationRequested) + { + // User requested cancellation + return Task.FromCanceled(context.cancellationToken); + } + else if (context.timeoutToken.IsCancellationRequested) + { + // Timeout + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); + } + else + { + // Prepare for stateObj timeout + reader.SetTimeout(reader._defaultTimeoutMilliseconds); + + int bytesReadThisIteration; + bool result = reader.TryGetBytesInternalSequential( + context.columnIndex, + context.buffer, + context.index + context.totalBytesRead, + context.length - context.totalBytesRead, + out bytesReadThisIteration + ); + context.totalBytesRead += bytesReadThisIteration; + Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required"); + + if (result) + { + return Task.FromResult(context.totalBytesRead); + } + else + { + return reader.ExecuteAsyncCall(context); + } + } + } + + private Task GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, bool isContinuation, out int bytesRead) { - _lastColumnWithDataChunkRead = i; + Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.Mode must be Read to read data"); + + _lastColumnWithDataChunkRead = context.columnIndex; TaskCompletionSource source = null; - CancellationTokenSource timeoutCancellationSource = null; // Prepare for stateObj timeout SetTimeout(_defaultTimeoutMilliseconds); // Try to read without any continuations (all the data may already be in the stateObj's buffer) - if (!TryGetBytesInternalSequential(i, buffer, index, length, out bytesRead)) + bool filledBuffer = context._reader.TryGetBytesInternalSequential( + context.columnIndex, + context.buffer, + context.index + context.totalBytesRead, + context.length - context.totalBytesRead, + out bytesRead + ); + context.totalBytesRead += bytesRead; + Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required"); + + if (!filledBuffer) { // This will be the 'state' for the callback int totalBytesRead = bytesRead; @@ -5008,6 +5081,7 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in if (!isContinuation) { // This is the first async operation which is happening - setup the _currentTask and timeout + Debug.Assert(context._source == null, "context._source should not be non-null when trying to change to async"); source = new TaskCompletionSource(); Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); if (original != null) @@ -5016,6 +5090,7 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in return source.Task; } + context._source = source; // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) if (_cancelAsyncOnCloseToken.IsCancellationRequested) { @@ -5025,52 +5100,18 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in } // Timeout - Debug.Assert(timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation"); - if (timeout > 0) + Debug.Assert(context.timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation"); + if (context.timeout > 0) { - timeoutCancellationSource = new CancellationTokenSource(); - timeoutCancellationSource.CancelAfter(timeout); - timeoutToken = timeoutCancellationSource.Token; + CancellationTokenSource timeoutCancellationSource = new CancellationTokenSource(); + timeoutCancellationSource.CancelAfter(context.timeout); + Debug.Assert(context._disposable is null, "setting context.disposable would lose the previous dispoable"); + context._disposable = timeoutCancellationSource; + context.timeoutToken = timeoutCancellationSource.Token; } } - Func> moreFunc = null; - moreFunc = (_ => - { - PrepareForAsyncContinuation(); - - if (cancellationToken.IsCancellationRequested) - { - // User requested cancellation - return ADP.CreatedTaskWithCancellation(); - } - else if (timeoutToken.IsCancellationRequested) - { - // Timeout - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); - } - else - { - // Prepare for stateObj timeout - SetTimeout(_defaultTimeoutMilliseconds); - - int bytesReadThisIteration; - bool result = TryGetBytesInternalSequential(i, buffer, index + totalBytesRead, length - totalBytesRead, out bytesReadThisIteration); - totalBytesRead += bytesReadThisIteration; - Debug.Assert(totalBytesRead <= length, "Read more bytes than required"); - - if (result) - { - return Task.FromResult(totalBytesRead); - } - else - { - return ContinueRetryable(moreFunc); - } - } - }); - - Task retryTask = ContinueRetryable(moreFunc); + Task retryTask = ExecuteAsyncCall(context); if (isContinuation) { // Let the caller handle cleanup\completing @@ -5078,8 +5119,13 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in } else { + Debug.Assert(context._source != null, "context._source shuld not be null when continuing"); // setup for cleanup\completing - retryTask.ContinueWith((Task t) => CompleteRetryable(t, source, timeoutCancellationSource), TaskScheduler.Default); + retryTask.ContinueWith( + continuationAction: AAsyncCallContext.s_completeCallback, + state: context, + TaskScheduler.Default + ); return source.Task; } } @@ -5132,42 +5178,42 @@ public override Task ReadAsync(CancellationToken cancellationToken) { _stateObj._shouldHaveEnoughData = true; #endif - if (_sharedState._dataReady) - { - // Clean off current row - CleanPartialReadReliable(); - } + if (_sharedState._dataReady) + { + // Clean off current row + CleanPartialReadReliable(); + } - // If there a ROW token ready (as well as any metadata for the row) - if (_stateObj.IsRowTokenReady()) - { - // Read the ROW token - bool result = TryReadInternal(true, out more); - Debug.Assert(result, "Should not have run out of data"); + // If there a ROW token ready (as well as any metadata for the row) + if (_stateObj.IsRowTokenReady()) + { + // Read the ROW token + bool result = TryReadInternal(true, out more); + Debug.Assert(result, "Should not have run out of data"); - rowTokenRead = true; - if (more) + rowTokenRead = true; + if (more) + { + // Sequential mode, nothing left to do + if (IsCommandBehavior(CommandBehavior.SequentialAccess)) { - // Sequential mode, nothing left to do - if (IsCommandBehavior(CommandBehavior.SequentialAccess)) - { - return ADP.TrueTask; - } - // For non-sequential, check if we can read the row data now - else if (WillHaveEnoughData(_metaData.Length - 1)) - { - // Read row data - result = TryReadColumn(_metaData.Length - 1, setTimeout: true); - Debug.Assert(result, "Should not have run out of data"); - return ADP.TrueTask; - } + return ADP.TrueTask; } - else + // For non-sequential, check if we can read the row data now + else if (WillHaveEnoughData(_metaData.Length - 1)) { - // No data left, return - return ADP.FalseTask; + // Read row data + result = TryReadColumn(_metaData.Length - 1, setTimeout: true); + Debug.Assert(result, "Should not have run out of data"); + return ADP.TrueTask; } } + else + { + // No data left, return + return ADP.FalseTask; + } + } #if DEBUG } finally @@ -5205,53 +5251,68 @@ public override Task ReadAsync(CancellationToken cancellationToken) IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(_command.CancelIgnoreFailure); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } + var context = Interlocked.Exchange(ref _cachedReadAsyncContext, null) ?? new ReadAsyncCallContext(); + + Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); + + context.Set(this, source, registration); + context._hasMoreData = more; + context._hasReadRowToken = rowTokenRead; + PrepareAsyncInvocation(useSnapshot: true); - Func> moreFunc = null; - moreFunc = (t) => + return InvokeAsyncCall(context); + } + } + + private static Task ReadAsyncExecute(Task task, object state) + { + var context = (ReadAsyncCallContext)state; + SqlDataReader reader = context._reader; + ref bool hasMoreData = ref context._hasMoreData; + ref bool hasReadRowToken = ref context._hasReadRowToken; + + if (task != null) + { + reader.PrepareForAsyncContinuation(); + } + + if (hasReadRowToken || reader.TryReadInternal(true, out hasMoreData)) + { + // If there are no more rows, or this is Sequential Access, then we are done + if (!hasMoreData || (reader._commandBehavior & CommandBehavior.SequentialAccess) == CommandBehavior.SequentialAccess) + { + // completed + return hasMoreData ? ADP.TrueTask : ADP.FalseTask; + } + else { - if (t != null) + // First time reading the row token - update the snapshot + if (!hasReadRowToken) { - SqlClientEventSource.Log.TryTraceEvent(" attempt retry {0}", ObjectID); - PrepareForAsyncContinuation(); + hasReadRowToken = true; + reader._snapshot = null; + reader.PrepareAsyncInvocation(useSnapshot: true); } - if (rowTokenRead || TryReadInternal(true, out more)) + // if non-sequentialaccess then read entire row before returning + if (reader.TryReadColumn(reader._metaData.Length - 1, true)) { - - // If there are no more rows, or this is Sequential Access, then we are done - if (!more || (_commandBehavior & CommandBehavior.SequentialAccess) == CommandBehavior.SequentialAccess) - { - // completed - return more ? ADP.TrueTask : ADP.FalseTask; - } - else - { - // First time reading the row token - update the snapshot - if (!rowTokenRead) - { - rowTokenRead = true; - _snapshot = null; - PrepareAsyncInvocation(useSnapshot: true); - } - - // if non-sequentialaccess then read entire row before returning - if (TryReadColumn(_metaData.Length - 1, true)) - { - // completed - return ADP.TrueTask; - } - } + // completed + return ADP.TrueTask; } + } + } - return ContinueRetryable(moreFunc); - }; + return reader.ExecuteAsyncCall(context); + } - return InvokeRetryable(moreFunc, source, registration); - } + private void SetCachedReadAsyncCallContext(ReadAsyncCallContext instance) + { + Interlocked.CompareExchange(ref _cachedReadAsyncContext, instance, null); } /// @@ -5309,8 +5370,8 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo { _stateObj._shouldHaveEnoughData = true; #endif - ReadColumnHeader(i); - return _data[i].IsNull ? ADP.TrueTask : ADP.FalseTask; + ReadColumnHeader(i); + return _data[i].IsNull ? ADP.TrueTask : ADP.FalseTask; #if DEBUG } finally @@ -5350,40 +5411,51 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(_command.CancelIgnoreFailure); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } + IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); + + Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); + + context.Set(this, source, registration); + context._columnIndex = i; + // Setup async PrepareAsyncInvocation(useSnapshot: true); - // Setup the retryable function - Func> moreFunc = null; - moreFunc = (t) => - { - if (t != null) - { - PrepareForAsyncContinuation(); - } + return InvokeAsyncCall(context); + } + } - if (TryReadColumnHeader(i)) - { - return _data[i].IsNull ? ADP.TrueTask : ADP.FalseTask; - } - else - { - return ContinueRetryable(moreFunc); - } - }; + private static Task IsDBNullAsyncExecute(Task task, object state) + { + IsDBNullAsyncCallContext context = (IsDBNullAsyncCallContext)state; + SqlDataReader reader = context._reader; + + if (task != null) + { + reader.PrepareForAsyncContinuation(); + } - // Go! - return InvokeRetryable(moreFunc, source, registration); + if (reader.TryReadColumnHeader(context._columnIndex)) + { + return reader._data[context._columnIndex].IsNull ? ADP.TrueTask : ADP.FalseTask; + } + else + { + return reader.ExecuteAsyncCall(context); } } + private void SetCachedIDBNullAsyncCallContext(IsDBNullAsyncCallContext instance) + { + Interlocked.CompareExchange(ref _cachedIsDBNullContext, instance, null); + } + /// override public Task GetFieldValueAsync(int i, CancellationToken cancellationToken) { - try { CheckDataIsReady(columnIndex: i, methodName: "GetFieldValueAsync"); @@ -5435,7 +5507,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { _stateObj._shouldHaveEnoughData = true; #endif - return Task.FromResult(GetFieldValueInternal(i)); + return Task.FromResult(GetFieldValueInternal(i)); #if DEBUG } finally @@ -5475,33 +5547,30 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(_command.CancelIgnoreFailure); + registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); } - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registration, i)); + } - // Setup the retryable function - Func> moreFunc = null; - moreFunc = (t) => + private static Task GetFieldValueAsyncExecute(Task task, object state) + { + GetFieldValueAsyncCallContext context = (GetFieldValueAsyncCallContext)state; + SqlDataReader reader = context._reader; + int columnIndex = context._columnIndex; + if (task != null) { - if (t != null) - { - PrepareForAsyncContinuation(); - } - - if (TryReadColumn(i, setTimeout: false)) - { - return Task.FromResult(GetFieldValueFromSqlBufferInternal(_data[i], _metaData[i])); - } - else - { - return ContinueRetryable(moreFunc); - } - }; + reader.PrepareForAsyncContinuation(); + } - // Go! - return InvokeRetryable(moreFunc, source, registration); + if (reader.TryReadColumn(columnIndex, setTimeout: false)) + { + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex])); + } + else + { + return reader.ExecuteAsyncCall(context); + } } #if DEBUG @@ -5547,85 +5616,174 @@ class Snapshot public SqlSequentialTextReader _currentTextReader; } - private Task ContinueRetryable(Func> moreFunc) + private abstract class AAsyncCallContext : IDisposable { - // _networkPacketTaskSource could be null if the connection was closed - // while an async invocation was outstanding. - TaskCompletionSource completionSource = _stateObj._networkPacketTaskSource; - if (_cancelAsyncOnCloseToken.IsCancellationRequested || completionSource == null) + internal static readonly Action, object> s_completeCallback = SqlDataReader.CompleteAsyncCallCallback; + + internal static readonly Func> s_executeCallback = SqlDataReader.ExecuteAsyncCallCallback; + + internal SqlDataReader _reader; + internal TaskCompletionSource _source; + internal IDisposable _disposable; + + protected AAsyncCallContext() { - // Cancellation requested due to datareader being closed - TaskCompletionSource source = new TaskCompletionSource(); - source.TrySetException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); - return source.Task; } - else + + protected AAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable = null) + { + Set(reader, source, disposable); + } + + internal void Set(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable = null) + { + this._reader = reader ?? throw new ArgumentNullException(nameof(reader)); + this._source = source ?? throw new ArgumentNullException(nameof(source)); + this._disposable = disposable; + } + + internal void Clear() + { + _source = null; + _reader = null; + IDisposable copyDisposable = _disposable; + _disposable = null; + copyDisposable?.Dispose(); + } + + internal abstract Func> Execute { get; } + + public virtual void Dispose() + { + Clear(); + } + } + + private sealed class ReadAsyncCallContext : AAsyncCallContext + { + internal static readonly Func> s_execute = SqlDataReader.ReadAsyncExecute; + + internal bool _hasMoreData; + internal bool _hasReadRowToken; + + internal ReadAsyncCallContext() + { + } + + internal override Func> Execute => s_execute; + + public override void Dispose() + { + SqlDataReader reader = this._reader; + base.Dispose(); + reader.SetCachedReadAsyncCallContext(this); + } + } + + private sealed class IsDBNullAsyncCallContext : AAsyncCallContext + { + internal static readonly Func> s_execute = SqlDataReader.IsDBNullAsyncExecute; + + internal int _columnIndex; + + internal IsDBNullAsyncCallContext() { } + + internal override Func> Execute => s_execute; + + public override void Dispose() + { + SqlDataReader reader = this._reader; + base.Dispose(); + reader.SetCachedIDBNullAsyncCallContext(this); + } + } + + private sealed class HasNextResultAsyncCallContext : AAsyncCallContext + { + private static readonly Func> s_execute = SqlDataReader.NextResultAsyncExecute; + + public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) + : base(reader, source, disposable) { - return completionSource.Task.ContinueWith((retryTask) => - { - if (retryTask.IsFaulted) - { - // Somehow the network task faulted - return the exception - TaskCompletionSource exceptionSource = new TaskCompletionSource(); - exceptionSource.TrySetException(retryTask.Exception.InnerException); - return exceptionSource.Task; - } - else if (!_cancelAsyncOnCloseToken.IsCancellationRequested) - { - TdsParserStateObject stateObj = _stateObj; - if (stateObj != null) - { - // protect continuations against concurrent - // close and cancel - lock (stateObj) - { - if (_stateObj != null) - { // reader not closed while we waited for the lock - if (retryTask.IsCanceled) - { - if (_parser != null) - { - _parser.State = TdsParserState.Broken; // We failed to respond to attention, we have to quit! - _parser.Connection.BreakConnection(); - _parser.ThrowExceptionAndWarning(_stateObj); - } - } - else - { - if (!IsClosed) - { - try - { - return moreFunc(retryTask); - } - catch (Exception) - { - CleanupAfterAsyncInvocation(); - throw; - } - } - } - } - } - } - } - // if stateObj is null, or we closed the connection or the connection was already closed, - // then mark this operation as cancelled. - TaskCompletionSource source = new TaskCompletionSource(); - source.SetException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); - return source.Task; - }, TaskScheduler.Default).Unwrap(); } + + internal override Func> Execute => s_execute; + } + + private sealed class GetBytesAsyncCallContext : AAsyncCallContext + { + internal enum OperationMode + { + Seek = 0, + Read = 1 + } + + private static readonly Func> s_executeSeek = SqlDataReader.GetBytesAsyncSeekExecute; + private static readonly Func> s_executeRead = SqlDataReader.GetBytesAsyncReadExecute; + + internal int columnIndex; + internal byte[] buffer; + internal int index; + internal int length; + internal int timeout; + internal CancellationToken cancellationToken; + internal CancellationToken timeoutToken; + internal int totalBytesRead; + + internal OperationMode mode; + + internal GetBytesAsyncCallContext(SqlDataReader reader) + { + this._reader = reader ?? throw new ArgumentNullException(nameof(reader)); + } + + internal override Func> Execute => mode == OperationMode.Seek ? s_executeSeek : s_executeRead; + + public override void Dispose() + { + buffer = null; + cancellationToken = default; + timeoutToken = default; + base.Dispose(); + } + } + + private sealed class GetFieldValueAsyncCallContext : AAsyncCallContext + { + private static readonly Func> s_execute = SqlDataReader.GetFieldValueAsyncExecute; + + internal readonly int _columnIndex; + + internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable, int columnIndex) + : base(reader, source, disposable) + { + _columnIndex = columnIndex; + } + + internal override Func> Execute => s_execute; } - private Task InvokeRetryable(Func> moreFunc, TaskCompletionSource source, IDisposable objectToDispose = null) + private static Task ExecuteAsyncCallCallback(Task task, object state) { + AAsyncCallContext context = (AAsyncCallContext)state; + return context._reader.ExecuteAsyncCall(task, context); + } + + private static void CompleteAsyncCallCallback(Task task, object state) + { + AAsyncCallContext context = (AAsyncCallContext)state; + context._reader.CompleteAsyncCall(task, context); + } + + private Task InvokeAsyncCall(AAsyncCallContext context) + { + TaskCompletionSource source = context._source; try { Task task; try { - task = moreFunc(null); + task = context.Execute(null, context); } catch (Exception ex) { @@ -5634,11 +5792,15 @@ private Task InvokeRetryable(Func> moreFunc, TaskCompletionS if (task.IsCompleted) { - CompleteRetryable(task, source, objectToDispose); + CompleteAsyncCall(task, context); } else { - task.ContinueWith((Task t) => CompleteRetryable(t, source, objectToDispose), TaskScheduler.Default); + task.ContinueWith( + continuationAction: AAsyncCallContext.s_completeCallback, + state: context, + TaskScheduler.Default + ); } } catch (AggregateException e) @@ -5654,17 +5816,110 @@ private Task InvokeRetryable(Func> moreFunc, TaskCompletionS return source.Task; } - private void CompleteRetryable(Task task, TaskCompletionSource source, IDisposable objectToDispose) + /// + /// Begins an async call checking for cancellation and then setting up the callback for when data is available + /// + /// + /// + /// + private Task ExecuteAsyncCall(AAsyncCallContext context) + { + // _networkPacketTaskSource could be null if the connection was closed + // while an async invocation was outstanding. + TaskCompletionSource completionSource = _stateObj._networkPacketTaskSource; + if (_cancelAsyncOnCloseToken.IsCancellationRequested || completionSource == null) + { + // Cancellation requested due to datareader being closed + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); + } + else + { + return completionSource.Task.ContinueWith( + continuationFunction: AAsyncCallContext.s_executeCallback, + state: context, + TaskScheduler.Default + ).Unwrap(); + } + } + + /// + /// When data has become available for an async call it is woken and this method is called. + /// It will call the async execution func and if a Task is returned indicating more data + /// is needed it will wait until it is called again when more is available + /// + /// + /// + /// + /// + private Task ExecuteAsyncCall(Task task, AAsyncCallContext context) { - if (objectToDispose != null) + // this function must be an instance function called from the static callback because otherwise a compiler error + // is caused by accessing the _cancelAsyncOnCloseToken field of a MarchalByRefObject derived class + if (task.IsFaulted) { - objectToDispose.Dispose(); + // Somehow the network task faulted - return the exception + return Task.FromException(task.Exception.InnerException); } + else if (!_cancelAsyncOnCloseToken.IsCancellationRequested) + { + TdsParserStateObject stateObj = _stateObj; + if (stateObj != null) + { + // protect continuations against concurrent + // close and cancel + lock (stateObj) + { + if (_stateObj != null) + { // reader not closed while we waited for the lock + if (task.IsCanceled) + { + if (_parser != null) + { + _parser.State = TdsParserState.Broken; // We failed to respond to attention, we have to quit! + _parser.Connection.BreakConnection(); + _parser.ThrowExceptionAndWarning(_stateObj); + } + } + else + { + if (!IsClosed) + { + try + { + return context.Execute(task, context); + } + catch (Exception) + { + CleanupAfterAsyncInvocation(); + throw; + } + } + } + } + } + } + } + // if stateObj is null, or we closed the connection or the connection was already closed, + // then mark this operation as cancelled. + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError())); + } + + /// + /// When data has been successfully processed for an async call the async func will call this + /// function to set the result into the task and cleanup the async state ready for another call + /// + /// + /// + /// + private void CompleteAsyncCall(Task task, AAsyncCallContext context) + { + TaskCompletionSource source = context._source; + context.Dispose(); // If something has forced us to switch to SyncOverAsync mode while in an async task then we need to guarantee that we do the cleanup // This avoids us replaying non-replayable data (such as DONE or ENV_CHANGE tokens) var stateObj = _stateObj; - bool ignoreCloseToken = ((stateObj != null) && (stateObj._syncOverAsync)); + bool ignoreCloseToken = (stateObj != null) && (stateObj._syncOverAsync); CleanupAfterAsyncInvocation(ignoreCloseToken); Task current = Interlocked.CompareExchange(ref _currentTask, null, source.Task); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs index a5b67acdd5..2c5bce8c82 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Data; using System.Data.SqlTypes; using System.IO; @@ -43,7 +44,7 @@ public static async Task AsyncMultiPacketStreamRead() { int packetSize = 514; // force small packet size so we can quickly check multi packet reads - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString); + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); builder.PacketSize = 514; string connectionString = builder.ToString(); @@ -51,7 +52,7 @@ public static async Task AsyncMultiPacketStreamRead() byte[] outputData = null; string tableName = DataTestUtility.GetUniqueNameForSqlServer("data"); - using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlConnection connection = new(connectionString)) { await connection.OpenAsync(); @@ -59,19 +60,17 @@ public static async Task AsyncMultiPacketStreamRead() { inputData = CreateBinaryTable(connection, tableName, packetSize); - using (SqlCommand command = new SqlCommand($"SELECT foo FROM {tableName}", connection)) - using (SqlDataReader reader = await command.ExecuteReaderAsync(System.Data.CommandBehavior.SequentialAccess)) - { - await reader.ReadAsync(); + using SqlCommand command = new($"SELECT foo FROM {tableName}", connection); + using SqlDataReader reader = await command.ExecuteReaderAsync(System.Data.CommandBehavior.SequentialAccess); + await reader.ReadAsync(); - using (Stream stream = reader.GetStream(0)) - using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(60))) - using (MemoryStream memory = new MemoryStream(16 * 1024)) - { - await stream.CopyToAsync(memory, 37, cancellationTokenSource.Token); // prime number sized buffer to cause many cross packet partial reads - outputData = memory.ToArray(); - } - } + using Stream stream = reader.GetStream(0); + using CancellationTokenSource cancellationTokenSource = new(TimeSpan.FromSeconds(60)); + using MemoryStream memory = new(16 * 1024); + + // prime number sized buffer to cause many cross packet partial reads + await LocalCopyTo(stream, memory, 37, cancellationTokenSource.Token); + outputData = memory.ToArray(); } finally { @@ -79,6 +78,30 @@ public static async Task AsyncMultiPacketStreamRead() } } + static async Task LocalCopyTo(Stream source, Stream destination, int bufferSize, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(bufferSize); + try + { + int bytesRead; +#if NETFRAMEWORK + while ((bytesRead = await source.ReadAsync(buffer, 0, bufferSize, cancellationToken).ConfigureAwait(false)) != 0) + { + await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false); + } +#else + while ((bytesRead = await source.ReadAsync(new Memory(buffer,0, bufferSize), cancellationToken).ConfigureAwait(false)) != 0) + { + await destination.WriteAsync(new ReadOnlyMemory(buffer, 0, bytesRead), cancellationToken).ConfigureAwait(false); + } +#endif + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + Assert.NotNull(outputData); int sharedLength = Math.Min(inputData.Length, outputData.Length); if (sharedLength < outputData.Length)