diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
index 95401f84ee0..f07af146e25 100644
--- a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
+++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
@@ -177,7 +177,7 @@ public override bool NextResult()
// It's a SELECT statement
if (sqlite3_column_count(stmt) != 0)
{
- _record = new SqliteDataRecord(stmt, rc != SQLITE_DONE, _command.Connection);
+ _record = new SqliteDataRecord(stmt, rc != SQLITE_DONE, _command.Connection, AddChanges);
return true;
}
@@ -191,14 +191,7 @@ public override bool NextResult()
sqlite3_reset(stmt);
var changes = sqlite3_changes(_command.Connection.Handle);
- if (_recordsAffected == -1)
- {
- _recordsAffected = changes;
- }
- else
- {
- _recordsAffected += changes;
- }
+ AddChanges(changes);
}
catch
{
@@ -219,6 +212,18 @@ private static bool IsBusy(int rc)
|| rc == SQLITE_BUSY
|| rc == SQLITE_LOCKED_SHAREDCACHE;
+ private void AddChanges(int changes)
+ {
+ if (_recordsAffected == -1)
+ {
+ _recordsAffected = changes;
+ }
+ else
+ {
+ _recordsAffected += changes;
+ }
+ }
+
///
/// Closes the data reader.
///
@@ -242,6 +247,7 @@ protected override void Dispose(bool disposing)
_command.DataReader = null;
_record?.Dispose();
+ _record = null;
if (_stmtEnumerator != null)
{
@@ -249,7 +255,6 @@ protected override void Dispose(bool disposing)
{
while (NextResult())
{
- _record!.Dispose();
}
}
catch
diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
index 816d0db0474..52a996fa4b9 100644
--- a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
+++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
@@ -16,18 +16,22 @@ namespace Microsoft.Data.Sqlite
internal class SqliteDataRecord : SqliteValueReader, IDisposable
{
private readonly SqliteConnection _connection;
+ private readonly Action _addChanges;
private byte[][]? _blobCache;
private int?[]? _typeCache;
private Dictionary? _columnNameOrdinalCache;
private string[]? _columnNameCache;
private bool _stepped;
private int? _rowidOrdinal;
+ private bool _alreadyThrown;
+ private bool _alreadyAddedChanges;
- public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connection)
+ public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connection, Action addChanges)
{
Handle = stmt;
HasRows = hasRows;
_connection = connection;
+ _addChanges = addChanges;
}
public virtual object this[string name]
@@ -397,19 +401,59 @@ public bool Read()
return false;
}
- var rc = sqlite3_step(Handle);
- SqliteException.ThrowExceptionForRC(rc, _connection.Handle);
+ int rc;
+ try
+ {
+ rc = sqlite3_step(Handle);
+ SqliteException.ThrowExceptionForRC(rc, _connection.Handle);
+ }
+ catch
+ {
+ _alreadyThrown = true;
+
+ throw;
+ }
if (_blobCache != null)
{
Array.Clear(_blobCache, 0, _blobCache.Length);
}
- return rc != SQLITE_DONE;
+ if (rc != SQLITE_DONE)
+ {
+ return true;
+ }
+
+ AddChanges();
+ _alreadyAddedChanges = true;
+
+ return false;
}
public void Dispose()
- => sqlite3_reset(Handle);
+ {
+ var rc = sqlite3_reset(Handle);
+ if (!_alreadyThrown)
+ {
+ SqliteException.ThrowExceptionForRC(rc, _connection.Handle);
+ }
+
+ if (!_alreadyAddedChanges)
+ {
+ AddChanges();
+ }
+ }
+
+ private void AddChanges()
+ {
+ if (sqlite3_stmt_readonly(Handle) != 0)
+ {
+ return;
+ }
+
+ var changes = sqlite3_changes(_connection.Handle);
+ _addChanges(changes);
+ }
private byte[] GetCachedBlob(int ordinal)
{
diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs
index 30adda5ae93..bbb13c8daba 100644
--- a/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs
+++ b/test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs
@@ -932,6 +932,120 @@ await Task.WhenAll(
}
}
+ [Fact]
+ public Task ExecuteScalar_throws_when_busy_with_returning()
+ => Execute_throws_when_busy_with_returning(command =>
+ {
+ var ex = Assert.Throws(
+ () => command.ExecuteScalar());
+
+ Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
+ });
+
+ [Fact]
+ public Task ExecuteNonQuery_throws_when_busy_with_returning()
+ => Execute_throws_when_busy_with_returning(command =>
+ {
+ var ex = Assert.Throws(
+ () => command.ExecuteNonQuery());
+
+ Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
+ });
+
+ [Fact]
+ public Task ExecuteReader_throws_when_busy_with_returning()
+ => Execute_throws_when_busy_with_returning(command =>
+ {
+ var reader = command.ExecuteReader();
+ try
+ {
+ Assert.True(reader.Read());
+ Assert.Equal(2L, reader.GetInt64(0));
+ }
+ finally
+ {
+ var ex = Assert.Throws(
+ () => reader.Dispose());
+
+ Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
+ }
+ });
+
+ [Fact]
+ public Task ExecuteReader_throws_when_busy_with_returning_while_draining()
+ => Execute_throws_when_busy_with_returning(command =>
+ {
+ using var reader = command.ExecuteReader();
+ Assert.True(reader.Read());
+ Assert.Equal(2L, reader.GetInt64(0));
+ Assert.True(reader.Read());
+ Assert.Equal(3L, reader.GetInt64(0));
+
+ var ex = Assert.Throws(
+ () => reader.Read());
+
+ Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
+ });
+
+ private static async Task Execute_throws_when_busy_with_returning(Action action)
+ {
+ const string connectionString = "Data Source=returning.db";
+
+ var selectedSignal = new AutoResetEvent(initialState: false);
+
+ try
+ {
+ using var connection1 = new SqliteConnection(connectionString);
+
+ if (new Version(connection1.ServerVersion) < new Version(3, 35, 0))
+ {
+ // Skip. RETURNING clause not supported
+ return;
+ }
+
+ connection1.Open();
+
+ connection1.ExecuteNonQuery(
+ "CREATE TABLE Data (Value); INSERT INTO Data VALUES (0);");
+
+ await Task.WhenAll(
+ Task.Run(
+ async () =>
+ {
+ using var connection = new SqliteConnection(connectionString);
+ connection.Open();
+
+ using (connection.ExecuteReader("SELECT * FROM Data;"))
+ {
+ selectedSignal.Set();
+
+ await Task.Delay(1000);
+ }
+ }),
+ Task.Run(
+ () =>
+ {
+ using var connection = new SqliteConnection(connectionString);
+ connection.Open();
+
+ selectedSignal.WaitOne();
+
+ var command = connection.CreateCommand();
+ command.CommandText = "INSERT INTO Data VALUES (1),(2) RETURNING rowid;";
+
+ action(command);
+ }));
+
+ var count = connection1.ExecuteScalar("SELECT COUNT(*) FROM Data;");
+ Assert.Equal(1L, count);
+ }
+ finally
+ {
+ SqliteConnection.ClearPool(new SqliteConnection(connectionString));
+ File.Delete("returning.db");
+ }
+ }
+
[Fact]
public void ExecuteReader_honors_CommandTimeout()
{
diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
index 62a43f59737..b71658c7a08 100644
--- a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
+++ b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
@@ -1881,6 +1881,73 @@ public void RecordsAffected_works_during_enumeration()
}
}
+ [Fact]
+ public void RecordsAffected_works_with_returning()
+ {
+ using (var connection = new SqliteConnection("Data Source=:memory:"))
+ {
+ if (new Version(connection.ServerVersion) < new Version(3, 35, 0))
+ {
+ // Skip. RETURNING clause not supported
+ return;
+ }
+
+ connection.Open();
+ connection.ExecuteNonQuery("CREATE TABLE Test(Value);");
+
+ var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1) RETURNING rowid;");
+ ((IDisposable)reader).Dispose();
+
+ Assert.Equal(1, reader.RecordsAffected);
+ }
+ }
+
+ [Fact]
+ public void RecordsAffected_works_with_returning_before_dispose_after_draining()
+ {
+ using (var connection = new SqliteConnection("Data Source=:memory:"))
+ {
+ if (new Version(connection.ServerVersion) < new Version(3, 35, 0))
+ {
+ // Skip. RETURNING clause not supported
+ return;
+ }
+
+ connection.Open();
+ connection.ExecuteNonQuery("CREATE TABLE Test(Value);");
+
+ using (var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1) RETURNING rowid;"))
+ {
+ while (reader.Read())
+ {
+ }
+
+ Assert.Equal(1, reader.RecordsAffected);
+ }
+ }
+ }
+
+ [Fact]
+ public void RecordsAffected_works_with_returning_multiple()
+ {
+ using (var connection = new SqliteConnection("Data Source=:memory:"))
+ {
+ if (new Version(connection.ServerVersion) < new Version(3, 35, 0))
+ {
+ // Skip. RETURNING clause not supported
+ return;
+ }
+
+ connection.Open();
+ connection.ExecuteNonQuery("CREATE TABLE Test(Value);");
+
+ var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1),(2) RETURNING rowid;");
+ ((IDisposable)reader).Dispose();
+
+ Assert.Equal(2, reader.RecordsAffected);
+ }
+ }
+
[Fact]
public void GetSchemaTable_works()
{