Skip to content

Commit d12cd98

Browse files
committed
Improve Query performance and return an error if the query includes multiple statements
This commit changes Query to return an error if more than one SQL statement is provided. Previously, this library would only execute the last query statement. It also improves query construction performance by ~15%. This is a breaking a change since existing programs may rely on the broken mattn/go-sqlite3 implementation. That said, any program relying on this is also broken / using sqlite3 incorrectly. ``` goos: darwin goarch: arm64 pkg: github.com/charlievieth/go-sqlite3 cpu: Apple M4 Pro │ x1.txt │ x2.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkQuery-14 2.255µ ± 1% 1.837µ ± 1% -18.56% (p=0.000 n=10) Suite/BenchmarkQuerySimple-14 1.322µ ± 9% 1.124µ ± 4% -15.02% (p=0.000 n=10) geomean 1.727µ 1.436µ -16.81% │ x1.txt │ x2.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkQuery-14 664.0 ± 0% 656.0 ± 0% -1.20% (p=0.000 n=10) Suite/BenchmarkQuerySimple-14 472.0 ± 0% 456.0 ± 0% -3.39% (p=0.000 n=10) geomean 559.8 546.9 -2.30% │ x1.txt │ x2.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkQuery-14 23.00 ± 0% 22.00 ± 0% -4.35% (p=0.000 n=10) Suite/BenchmarkQuerySimple-14 14.00 ± 0% 13.00 ± 0% -7.14% (p=0.000 n=10) geomean 17.94 16.91 -5.76% ```
1 parent 3230831 commit d12cd98

File tree

2 files changed

+253
-36
lines changed

2 files changed

+253
-36
lines changed

sqlite3.go

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,59 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_
137137
}
138138
#endif
139139
140+
#define GO_SQLITE_MULTIPLE_QUERIES -1
141+
142+
// Our own implementation of ctype.h's isspace (for simplicity and to avoid
143+
// whatever locale shenanigans are involved with the Libc's isspace).
144+
static int _sqlite3_isspace(unsigned char c) {
145+
return c == ' ' || c - '\t' < 5;
146+
}
147+
148+
static int _sqlite3_prepare_query(sqlite3 *db, const char *zSql, int nBytes,
149+
sqlite3_stmt **ppStmt, int *paramCount) {
150+
151+
const char *tail;
152+
int rc = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
153+
if (rc != SQLITE_OK) {
154+
return rc;
155+
}
156+
*paramCount = sqlite3_bind_parameter_count(*ppStmt);
157+
158+
// Check if the SQL query contains multiple statements.
159+
160+
// Trim leading space to handle queries with trailing whitespace.
161+
// This can save us an additional call to sqlite3_prepare_v2.
162+
const char *end = zSql + nBytes;
163+
while (tail < end && _sqlite3_isspace(*tail)) {
164+
tail++;
165+
}
166+
nBytes -= (tail - zSql);
167+
168+
// Attempt to parse the remaining SQL, if any.
169+
if (nBytes > 0 && *tail) {
170+
sqlite3_stmt *stmt;
171+
rc = _sqlite3_prepare_v2_internal(db, tail, nBytes, &stmt, NULL);
172+
if (rc != SQLITE_OK) {
173+
// sqlite3 will return OK and a NULL statement if it was
174+
goto error;
175+
}
176+
if (stmt != NULL) {
177+
sqlite3_finalize(stmt);
178+
rc = GO_SQLITE_MULTIPLE_QUERIES;
179+
goto error;
180+
}
181+
}
182+
183+
// Ok, the SQL contained one valid statement.
184+
return SQLITE_OK;
185+
186+
error:
187+
if (*ppStmt) {
188+
sqlite3_finalize(*ppStmt);
189+
}
190+
return rc;
191+
}
192+
140193
static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) {
141194
const char *tail = NULL;
142195
int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
@@ -1125,46 +1178,42 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
11251178
return c.query(context.Background(), query, list)
11261179
}
11271180

1181+
var closedRows = &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}
1182+
11281183
func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
1129-
start := 0
1130-
for {
1131-
stmtArgs := make([]driver.NamedValue, 0, len(args))
1132-
s, err := c.prepare(ctx, query)
1133-
if err != nil {
1134-
return nil, err
1135-
}
1136-
s.(*SQLiteStmt).cls = true
1137-
na := s.NumInput()
1138-
if len(args)-start < na {
1139-
s.Close()
1140-
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
1141-
}
1142-
// consume the number of arguments used in the current
1143-
// statement and append all named arguments not contained
1144-
// therein
1145-
stmtArgs = append(stmtArgs, args[start:start+na]...)
1146-
for i := range args {
1147-
if (i < start || i >= na) && args[i].Name != "" {
1148-
stmtArgs = append(stmtArgs, args[i])
1149-
}
1150-
}
1151-
for i := range stmtArgs {
1152-
stmtArgs[i].Ordinal = i + 1
1153-
}
1154-
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
1155-
if err != nil && err != driver.ErrSkip {
1156-
s.Close()
1157-
return rows, err
1184+
s := SQLiteStmt{c: c, cls: true}
1185+
p := stringData(query)
1186+
var paramCount C.int
1187+
rv := C._sqlite3_prepare_query(c.db, (*C.char)(unsafe.Pointer(p)), C.int(len(query)), &s.s, &paramCount)
1188+
if rv != C.SQLITE_OK {
1189+
if rv == C.GO_SQLITE_MULTIPLE_QUERIES {
1190+
return nil, errors.New("query contains multiple SQL statements")
11581191
}
1159-
start += na
1160-
tail := s.(*SQLiteStmt).t
1161-
if tail == "" {
1162-
return rows, nil
1192+
return nil, c.lastError()
1193+
}
1194+
1195+
// The sqlite3_stmt will be nil if the SQL was valid but did not
1196+
// contain a query. For now we're supporting this for the sake of
1197+
// backwards compatibility, but that may change in the future.
1198+
if s.s == nil {
1199+
return closedRows, nil
1200+
}
1201+
1202+
na := int(paramCount)
1203+
if n := len(args); n != na {
1204+
s.finalize()
1205+
if n < na {
1206+
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
11631207
}
1164-
rows.Close()
1165-
s.Close()
1166-
query = tail
1208+
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
11671209
}
1210+
1211+
rows, err := s.query(ctx, args)
1212+
if err != nil && err != driver.ErrSkip {
1213+
s.finalize()
1214+
return rows, err
1215+
}
1216+
return rows, nil
11681217
}
11691218

11701219
// Begin transaction.

sqlite3_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"math/rand"
2020
"net/url"
2121
"os"
22+
"path/filepath"
2223
"reflect"
2324
"regexp"
2425
"runtime"
@@ -1203,6 +1204,163 @@ func TestQueryer(t *testing.T) {
12031204
}
12041205
}
12051206

1207+
func testQuery(t *testing.T, test func(t *testing.T, db *sql.DB)) {
1208+
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
1209+
if err != nil {
1210+
t.Fatal("Failed to open database:", err)
1211+
}
1212+
defer db.Close()
1213+
1214+
_, err = db.Exec(`
1215+
CREATE TABLE FOO (id INTEGER);
1216+
INSERT INTO foo(id) VALUES(?);
1217+
INSERT INTO foo(id) VALUES(?);
1218+
INSERT INTO foo(id) VALUES(?);
1219+
`, 3, 2, 1)
1220+
if err != nil {
1221+
t.Fatal(err)
1222+
}
1223+
1224+
// Capture panic so tests can continue
1225+
defer func() {
1226+
if e := recover(); e != nil {
1227+
buf := make([]byte, 32*1024)
1228+
n := runtime.Stack(buf, false)
1229+
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
1230+
}
1231+
}()
1232+
test(t, db)
1233+
}
1234+
1235+
func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
1236+
var values []interface{}
1237+
testQuery(t, func(t *testing.T, db *sql.DB) {
1238+
rows, err := db.Query(query, args...)
1239+
if err != nil {
1240+
t.Fatal(err)
1241+
}
1242+
if rows == nil {
1243+
t.Fatal("nil rows")
1244+
}
1245+
for i := 0; rows.Next(); i++ {
1246+
if i > 1_000 {
1247+
t.Fatal("To many iterations of rows.Next():", i)
1248+
}
1249+
var v interface{}
1250+
if err := rows.Scan(&v); err != nil {
1251+
t.Fatal(err)
1252+
}
1253+
values = append(values, v)
1254+
}
1255+
if err := rows.Err(); err != nil {
1256+
t.Fatal(err)
1257+
}
1258+
if err := rows.Close(); err != nil {
1259+
t.Fatal(err)
1260+
}
1261+
})
1262+
return values
1263+
}
1264+
1265+
func TestQuery(t *testing.T) {
1266+
queries := []struct {
1267+
query string
1268+
args []interface{}
1269+
}{
1270+
{"SELECT id FROM foo ORDER BY id;", nil},
1271+
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
1272+
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
1273+
1274+
// Comments
1275+
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
1276+
{"SELECT id FROM foo ORDER BY id -- comment", nil}, // Not terminated
1277+
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
1278+
{
1279+
`-- FOO
1280+
SELECT id FROM foo ORDER BY id; -- BAR
1281+
/* BAZ */`,
1282+
nil,
1283+
},
1284+
}
1285+
want := []interface{}{
1286+
int64(1),
1287+
int64(2),
1288+
int64(3),
1289+
}
1290+
for _, q := range queries {
1291+
t.Run("", func(t *testing.T) {
1292+
got := testQueryValues(t, q.query, q.args...)
1293+
if !reflect.DeepEqual(got, want) {
1294+
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
1295+
}
1296+
})
1297+
}
1298+
}
1299+
1300+
func TestQueryNoSQL(t *testing.T) {
1301+
got := testQueryValues(t, "")
1302+
if got != nil {
1303+
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
1304+
}
1305+
}
1306+
1307+
func testQueryError(t *testing.T, query string, args ...interface{}) {
1308+
testQuery(t, func(t *testing.T, db *sql.DB) {
1309+
rows, err := db.Query(query, args...)
1310+
if err == nil {
1311+
t.Error("Expected an error got:", err)
1312+
}
1313+
if rows != nil {
1314+
t.Error("Returned rows should be nil on error!")
1315+
// Attempt to iterate over rows to make sure they don't panic.
1316+
for i := 0; rows.Next(); i++ {
1317+
if i > 1_000 {
1318+
t.Fatal("To many iterations of rows.Next():", i)
1319+
}
1320+
}
1321+
if err := rows.Err(); err != nil {
1322+
t.Error(err)
1323+
}
1324+
rows.Close()
1325+
}
1326+
})
1327+
}
1328+
1329+
func TestQueryNotEnoughArgs(t *testing.T) {
1330+
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
1331+
}
1332+
1333+
func TestQueryTooManyArgs(t *testing.T) {
1334+
// TODO: test error message / kind
1335+
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
1336+
}
1337+
1338+
func TestQueryMultipleStatements(t *testing.T) {
1339+
testQueryError(t, "SELECT 1; SELECT 2;")
1340+
testQueryError(t, "SELECT 1; SELECT 2; SELECT 3;")
1341+
testQueryError(t, "SELECT 1; ; SELECT 2;") // Empty statement in between
1342+
testQueryError(t, "SELECT 1; FOOBAR 2;") // Error in second statement
1343+
1344+
// Test that multiple trailing semicolons (";;") are not an error
1345+
noError := func(t *testing.T, query string, args ...any) {
1346+
testQuery(t, func(t *testing.T, db *sql.DB) {
1347+
var n int64
1348+
if err := db.QueryRow(query, args...).Scan(&n); err != nil {
1349+
t.Fatal(err)
1350+
}
1351+
if n != 1 {
1352+
t.Fatalf("got: %d want: %d", n, 1)
1353+
}
1354+
})
1355+
}
1356+
noError(t, "SELECT 1; ;")
1357+
noError(t, "SELECT ?; ;", 1)
1358+
}
1359+
1360+
func TestQueryInvalidTable(t *testing.T) {
1361+
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
1362+
}
1363+
12061364
func TestStress(t *testing.T) {
12071365
tempFilename := TempFilename(t)
12081366
defer os.Remove(tempFilename)
@@ -2180,6 +2338,7 @@ var benchmarks = []testing.InternalBenchmark{
21802338
{Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep},
21812339
{Name: "BenchmarkExecTx", F: benchmarkExecTx},
21822340
{Name: "BenchmarkQuery", F: benchmarkQuery},
2341+
{Name: "BenchmarkQuerySimple", F: benchmarkQuerySimple},
21832342
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
21842343
{Name: "BenchmarkParams", F: benchmarkParams},
21852344
{Name: "BenchmarkStmt", F: benchmarkStmt},
@@ -2619,6 +2778,15 @@ func benchmarkQuery(b *testing.B) {
26192778
}
26202779
}
26212780

2781+
func benchmarkQuerySimple(b *testing.B) {
2782+
for i := 0; i < b.N; i++ {
2783+
var n int
2784+
if err := db.QueryRow("select 1;").Scan(&n); err != nil {
2785+
panic(err)
2786+
}
2787+
}
2788+
}
2789+
26222790
// benchmarkQueryContext is benchmark for QueryContext
26232791
func benchmarkQueryContext(b *testing.B) {
26242792
const createTableStmt = `

0 commit comments

Comments
 (0)