diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs index 95401f84ee0..f07af146e25 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs @@ -177,7 +177,7 @@ public override bool NextResult() // It's a SELECT statement if (sqlite3_column_count(stmt) != 0) { - _record = new SqliteDataRecord(stmt, rc != SQLITE_DONE, _command.Connection); + _record = new SqliteDataRecord(stmt, rc != SQLITE_DONE, _command.Connection, AddChanges); return true; } @@ -191,14 +191,7 @@ public override bool NextResult() sqlite3_reset(stmt); var changes = sqlite3_changes(_command.Connection.Handle); - if (_recordsAffected == -1) - { - _recordsAffected = changes; - } - else - { - _recordsAffected += changes; - } + AddChanges(changes); } catch { @@ -219,6 +212,18 @@ private static bool IsBusy(int rc) || rc == SQLITE_BUSY || rc == SQLITE_LOCKED_SHAREDCACHE; + private void AddChanges(int changes) + { + if (_recordsAffected == -1) + { + _recordsAffected = changes; + } + else + { + _recordsAffected += changes; + } + } + /// /// Closes the data reader. /// @@ -242,6 +247,7 @@ protected override void Dispose(bool disposing) _command.DataReader = null; _record?.Dispose(); + _record = null; if (_stmtEnumerator != null) { @@ -249,7 +255,6 @@ protected override void Dispose(bool disposing) { while (NextResult()) { - _record!.Dispose(); } } catch diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs index 816d0db0474..52a996fa4b9 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs @@ -16,18 +16,22 @@ namespace Microsoft.Data.Sqlite internal class SqliteDataRecord : SqliteValueReader, IDisposable { private readonly SqliteConnection _connection; + private readonly Action _addChanges; private byte[][]? _blobCache; private int?[]? _typeCache; private Dictionary? _columnNameOrdinalCache; private string[]? _columnNameCache; private bool _stepped; private int? _rowidOrdinal; + private bool _alreadyThrown; + private bool _alreadyAddedChanges; - public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connection) + public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connection, Action addChanges) { Handle = stmt; HasRows = hasRows; _connection = connection; + _addChanges = addChanges; } public virtual object this[string name] @@ -397,19 +401,59 @@ public bool Read() return false; } - var rc = sqlite3_step(Handle); - SqliteException.ThrowExceptionForRC(rc, _connection.Handle); + int rc; + try + { + rc = sqlite3_step(Handle); + SqliteException.ThrowExceptionForRC(rc, _connection.Handle); + } + catch + { + _alreadyThrown = true; + + throw; + } if (_blobCache != null) { Array.Clear(_blobCache, 0, _blobCache.Length); } - return rc != SQLITE_DONE; + if (rc != SQLITE_DONE) + { + return true; + } + + AddChanges(); + _alreadyAddedChanges = true; + + return false; } public void Dispose() - => sqlite3_reset(Handle); + { + var rc = sqlite3_reset(Handle); + if (!_alreadyThrown) + { + SqliteException.ThrowExceptionForRC(rc, _connection.Handle); + } + + if (!_alreadyAddedChanges) + { + AddChanges(); + } + } + + private void AddChanges() + { + if (sqlite3_stmt_readonly(Handle) != 0) + { + return; + } + + var changes = sqlite3_changes(_connection.Handle); + _addChanges(changes); + } private byte[] GetCachedBlob(int ordinal) { diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs index 30adda5ae93..bbb13c8daba 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs @@ -932,6 +932,120 @@ await Task.WhenAll( } } + [Fact] + public Task ExecuteScalar_throws_when_busy_with_returning() + => Execute_throws_when_busy_with_returning(command => + { + var ex = Assert.Throws( + () => command.ExecuteScalar()); + + Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode); + }); + + [Fact] + public Task ExecuteNonQuery_throws_when_busy_with_returning() + => Execute_throws_when_busy_with_returning(command => + { + var ex = Assert.Throws( + () => command.ExecuteNonQuery()); + + Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode); + }); + + [Fact] + public Task ExecuteReader_throws_when_busy_with_returning() + => Execute_throws_when_busy_with_returning(command => + { + var reader = command.ExecuteReader(); + try + { + Assert.True(reader.Read()); + Assert.Equal(2L, reader.GetInt64(0)); + } + finally + { + var ex = Assert.Throws( + () => reader.Dispose()); + + Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode); + } + }); + + [Fact] + public Task ExecuteReader_throws_when_busy_with_returning_while_draining() + => Execute_throws_when_busy_with_returning(command => + { + using var reader = command.ExecuteReader(); + Assert.True(reader.Read()); + Assert.Equal(2L, reader.GetInt64(0)); + Assert.True(reader.Read()); + Assert.Equal(3L, reader.GetInt64(0)); + + var ex = Assert.Throws( + () => reader.Read()); + + Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode); + }); + + private static async Task Execute_throws_when_busy_with_returning(Action action) + { + const string connectionString = "Data Source=returning.db"; + + var selectedSignal = new AutoResetEvent(initialState: false); + + try + { + using var connection1 = new SqliteConnection(connectionString); + + if (new Version(connection1.ServerVersion) < new Version(3, 35, 0)) + { + // Skip. RETURNING clause not supported + return; + } + + connection1.Open(); + + connection1.ExecuteNonQuery( + "CREATE TABLE Data (Value); INSERT INTO Data VALUES (0);"); + + await Task.WhenAll( + Task.Run( + async () => + { + using var connection = new SqliteConnection(connectionString); + connection.Open(); + + using (connection.ExecuteReader("SELECT * FROM Data;")) + { + selectedSignal.Set(); + + await Task.Delay(1000); + } + }), + Task.Run( + () => + { + using var connection = new SqliteConnection(connectionString); + connection.Open(); + + selectedSignal.WaitOne(); + + var command = connection.CreateCommand(); + command.CommandText = "INSERT INTO Data VALUES (1),(2) RETURNING rowid;"; + + action(command); + })); + + var count = connection1.ExecuteScalar("SELECT COUNT(*) FROM Data;"); + Assert.Equal(1L, count); + } + finally + { + SqliteConnection.ClearPool(new SqliteConnection(connectionString)); + File.Delete("returning.db"); + } + } + [Fact] public void ExecuteReader_honors_CommandTimeout() { diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs index 62a43f59737..b71658c7a08 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs @@ -1881,6 +1881,73 @@ public void RecordsAffected_works_during_enumeration() } } + [Fact] + public void RecordsAffected_works_with_returning() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + if (new Version(connection.ServerVersion) < new Version(3, 35, 0)) + { + // Skip. RETURNING clause not supported + return; + } + + connection.Open(); + connection.ExecuteNonQuery("CREATE TABLE Test(Value);"); + + var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1) RETURNING rowid;"); + ((IDisposable)reader).Dispose(); + + Assert.Equal(1, reader.RecordsAffected); + } + } + + [Fact] + public void RecordsAffected_works_with_returning_before_dispose_after_draining() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + if (new Version(connection.ServerVersion) < new Version(3, 35, 0)) + { + // Skip. RETURNING clause not supported + return; + } + + connection.Open(); + connection.ExecuteNonQuery("CREATE TABLE Test(Value);"); + + using (var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1) RETURNING rowid;")) + { + while (reader.Read()) + { + } + + Assert.Equal(1, reader.RecordsAffected); + } + } + } + + [Fact] + public void RecordsAffected_works_with_returning_multiple() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + if (new Version(connection.ServerVersion) < new Version(3, 35, 0)) + { + // Skip. RETURNING clause not supported + return; + } + + connection.Open(); + connection.ExecuteNonQuery("CREATE TABLE Test(Value);"); + + var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1),(2) RETURNING rowid;"); + ((IDisposable)reader).Dispose(); + + Assert.Equal(2, reader.RecordsAffected); + } + } + [Fact] public void GetSchemaTable_works() {