diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs index 38d4a8cd2ce..c6e2bace6f7 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs @@ -269,6 +269,7 @@ public virtual Stream GetStream(int ordinal) if (!_rowidOrdinal.HasValue) { _rowidOrdinal = -1; + var pkColumns = -1L; for (var i = 0; i < FieldCount; i++) { @@ -310,8 +311,22 @@ public virtual Stream GetStream(int ordinal) if (string.Equals(dataType, "INTEGER", StringComparison.OrdinalIgnoreCase) && primaryKey != 0) { - _rowidOrdinal = i; - break; + if (pkColumns < 0L) + { + using (var command = _connection.CreateCommand()) + { + command.CommandText = "SELECT COUNT(*) FROM pragma_table_info($table) WHERE pk != 0;"; + command.Parameters.AddWithValue("$table", tableName); + + pkColumns = (long)command.ExecuteScalar()!; + } + } + + if (pkColumns == 1L) + { + _rowidOrdinal = i; + break; + } } } diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs index 7a548763880..ee1af0426a1 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs @@ -456,6 +456,62 @@ public void GetStream_Blob_works(string createTableCmd, string selectCmd) } } + [Fact] + public void GetStream_works_when_composite_pk() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + + connection.ExecuteNonQuery( + @"CREATE TABLE DataTable (Id1 INTEGER, Id2 INTEGER, Data BLOB, PRIMARY KEY (Id1, Id2)); + INSERT INTO DataTable VALUES (5, 6, X'01020304');"); + + var selectCommand = connection.CreateCommand(); + selectCommand.CommandText = "SELECT Id1, Id2, Data FROM DataTable WHERE Id1 = 5 AND Id2 = 6"; + using (var reader = selectCommand.ExecuteReader()) + { + Assert.True(reader.Read()); + using (var sourceStream = reader.GetStream(2)) + { + Assert.IsType(sourceStream); + var buffer = new byte[4]; + var bytesRead = sourceStream.Read(buffer, 0, 4); + Assert.Equal(4, bytesRead); + Assert.Equal(new byte[] { 0x01, 0x02, 0x03, 0x04 }, buffer); + } + } + } + } + + [Fact] + public void GetStream_works_when_composite_pk_and_rowid() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + + connection.ExecuteNonQuery( + @"CREATE TABLE DataTable (Id1 INTEGER, Id2 INTEGER, Data BLOB, PRIMARY KEY (Id1, Id2)); + INSERT INTO DataTable VALUES (5, 6, X'01020304');"); + + var selectCommand = connection.CreateCommand(); + selectCommand.CommandText = "SELECT Id1, Id2, rowid, Data FROM DataTable WHERE Id1 = 5 AND Id2 = 6"; + using (var reader = selectCommand.ExecuteReader()) + { + Assert.True(reader.Read()); + using (var sourceStream = reader.GetStream(3)) + { + Assert.IsType(sourceStream); + var buffer = new byte[4]; + var bytesRead = sourceStream.Read(buffer, 0, 4); + Assert.Equal(4, bytesRead); + Assert.Equal(new byte[] { 0x01, 0x02, 0x03, 0x04 }, buffer); + } + } + } + } + [Fact] public void GetStream_throws_when_closed() {