Skip to content

Commit

Permalink
use extended error codes and handle closed db connections in lastError
Browse files Browse the repository at this point in the history
This commit changes the error handling logic so that it respects the
offending result code (instead of only relying on sqlite3_errcode) and
changes the db connection to always report the extended result code
(which eliminates the need to call sqlite3_extended_errcode). These
changes make it possible to correctly and safely handle errors when the
underlying db connection has been closed.
  • Loading branch information
charlievieth committed Dec 16, 2024
1 parent 69c42ee commit d2d4030
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 57 deletions.
5 changes: 4 additions & 1 deletion backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string)
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
return bb, nil
}
return nil, destConn.lastError()
if destConn.db != nil {
return nil, destConn.lastError(int(C.sqlite3_extended_errcode(destConn.db)))
}
return nil, Error{Code: 1, ExtendedCode: 1, err: "backup: destination connection is nil"}
}

// Step to backs up for one step. Calls the underlying `sqlite3_backup_step`
Expand Down
24 changes: 21 additions & 3 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ package sqlite3
#endif
*/
import "C"
import "syscall"
import (
"sync"
"syscall"
)

// ErrNo inherit errno.
type ErrNo int

// ErrNoMask is mask code.
const ErrNoMask C.int = 0xff
const ErrNoMask = 0xff

// ErrNoExtended is extended errno.
type ErrNoExtended int
Expand Down Expand Up @@ -85,7 +88,7 @@ func (err Error) Error() string {
if err.err != "" {
str = err.err
} else {
str = C.GoString(C.sqlite3_errstr(C.int(err.Code)))
str = errorString(int(err.Code))
}
if err.SystemErrno != 0 {
str += ": " + err.SystemErrno.Error()
Expand Down Expand Up @@ -148,3 +151,18 @@ var (
ErrNoticeRecoverRollback = ErrNotice.Extend(2)
ErrWarningAutoIndex = ErrWarning.Extend(1)
)

var errStrCache sync.Map // int => string

// errorString returns the result of sqlite3_errstr for result code
// rv which may be cached.
func errorString(rv int) string {
if v, ok := errStrCache.Load(rv); ok {
return v.(string)
}
s := C.GoString(C.sqlite3_errstr(C.int(rv)))
if v, loaded := errStrCache.LoadOrStore(rv, s); loaded {
return v.(string)
}
return s
}
59 changes: 35 additions & 24 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
defer C.free(unsafe.Pointer(cname))
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, handle, (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -675,7 +675,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl any, pure bool) error {
}
rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -804,7 +804,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand All @@ -816,32 +816,38 @@ func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0
}

func (c *SQLiteConn) lastError() error {
return lastError(c.db)
func (c *SQLiteConn) lastError(rv int) error {
return lastError(c.db, rv)
}

// Note: may be called with db == nil
func lastError(db *C.sqlite3) error {
rv := C.sqlite3_errcode(db) // returns SQLITE_NOMEM if db == nil
if rv == C.SQLITE_OK {
func lastError(db *C.sqlite3, rv int) error {
if rv == SQLITE_OK {
return nil
}
extrv := C.sqlite3_extended_errcode(db) // returns SQLITE_NOMEM if db == nil
errStr := C.GoString(C.sqlite3_errmsg(db)) // returns "out of memory" if db == nil
extrv := rv
// Convert the extended result code to a basic result code.
rv &= ErrNoMask

// https://www.sqlite.org/c3ref/system_errno.html
// sqlite3_system_errno is only meaningful if the error code was SQLITE_CANTOPEN,
// or it was SQLITE_IOERR and the extended code was not SQLITE_IOERR_NOMEM
var systemErrno syscall.Errno
if rv == C.SQLITE_CANTOPEN || (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM) {
if db != nil && (rv == C.SQLITE_CANTOPEN ||
(rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM)) {
systemErrno = syscall.Errno(C.sqlite3_system_errno(db))
}

var msg string
if db != nil {
msg = C.GoString(C.sqlite3_errmsg(db))
} else {
msg = errorString(extrv)
}
return Error{
Code: ErrNo(rv),
ExtendedCode: ErrNoExtended(extrv),
SystemErrno: systemErrno,
err: errStr,
err: msg,
}
}

Expand Down Expand Up @@ -1467,7 +1473,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if rv != 0 {
// Save off the error _before_ closing the database.
// This is safe even if db is nil.
err := lastError(db)
err := lastError(db, int(rv))
if db != nil {
C.sqlite3_close_v2(db)
}
Expand All @@ -1476,13 +1482,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
rv = C.sqlite3_extended_result_codes(db, 1)
if rv != SQLITE_OK {
C.sqlite3_close_v2(db)
return nil, lastError(db, int(rv))
}

exec := func(s string) error {
cs := C.CString(s)
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
C.free(unsafe.Pointer(cs))
if rv != C.SQLITE_OK {
return lastError(db)
return lastError(db, int(rv))
}
return nil
}
Expand Down Expand Up @@ -1791,7 +1802,7 @@ func (c *SQLiteConn) Close() (err error) {
runtime.SetFinalizer(c, nil)
rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK {
err = c.lastError()
err = lastError(nil, int(rv))
}
deleteHandles(c)
c.db = nil
Expand All @@ -1810,7 +1821,7 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
var tail *C.char
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
return nil, c.lastError(int(rv))
}
var t string
if tail != nil && *tail != '\000' {
Expand Down Expand Up @@ -1882,7 +1893,7 @@ func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error {
cArg := C.int(arg)
rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg))
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand All @@ -1902,7 +1913,7 @@ func (s *SQLiteStmt) Close() error {
runtime.SetFinalizer(s, nil)
rv := C.sqlite3_finalize(stmt)
if rv != C.SQLITE_OK {
return conn.lastError()
return conn.lastError(int(rv))
}
return nil
}
Expand All @@ -1917,7 +1928,7 @@ var placeHolder = []byte{0}
func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
return s.c.lastError(int(rv))
}

bindIndices := make([][3]int, len(args))
Expand Down Expand Up @@ -1975,7 +1986,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
return s.c.lastError(int(rv))
}
}
}
Expand Down Expand Up @@ -2087,7 +2098,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError()
err := s.c.lastError(int(rv))
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
Expand Down Expand Up @@ -2118,7 +2129,7 @@ func (rc *SQLiteRows) Close() error {
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
rc.s.mu.Unlock()
return rc.s.c.lastError()
return rc.s.c.lastError(int(rv))
}
rc.s.mu.Unlock()
rc.s = nil
Expand Down Expand Up @@ -2197,7 +2208,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
if rv != C.SQLITE_ROW {
rv = C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
return rc.s.c.lastError(int(rv))
}
return nil
}
Expand Down
20 changes: 7 additions & 13 deletions sqlite3_opt_unlock_notify.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
// license that can be found in the LICENSE file.

#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
#include <stdio.h>
#include "sqlite3-binding.h"

extern int unlock_notify_wait(sqlite3 *db);

static inline int is_locked(int rv) {
return rv == SQLITE_LOCKED || rv == SQLITE_LOCKED_SHAREDCACHE;
}

int
_sqlite3_step_blocking(sqlite3_stmt *stmt)
{
Expand All @@ -18,10 +21,7 @@ _sqlite3_step_blocking(sqlite3_stmt *stmt)
db = sqlite3_db_handle(stmt);
for (;;) {
rv = sqlite3_step(stmt);
if (rv != SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand All @@ -43,10 +43,7 @@ _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* chan
db = sqlite3_db_handle(stmt);
for (;;) {
rv = sqlite3_step(stmt);
if (rv!=SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand All @@ -68,10 +65,7 @@ _sqlite3_prepare_v2_blocking(sqlite3 *db, const char *zSql, int nBytes, sqlite3_

for (;;) {
rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
if (rv!=SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand Down
22 changes: 14 additions & 8 deletions sqlite3_opt_unlock_notify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ func TestUnlockNotify(t *testing.T) {
wg.Add(1)
timer := time.NewTimer(500 * time.Millisecond)
go func() {
defer wg.Done()
<-timer.C
err := tx.Commit()
if err != nil {
t.Fatal("Failed to commit transaction:", err)
}
wg.Done()
}()

rows, err := db.Query("SELECT count(*) from foo")
Expand Down Expand Up @@ -111,33 +111,39 @@ func TestUnlockNotifyMany(t *testing.T) {
wg.Add(1)
timer := time.NewTimer(500 * time.Millisecond)
go func() {
defer wg.Done()
<-timer.C
err := tx.Commit()
if err != nil {
t.Fatal("Failed to commit transaction:", err)
}
wg.Done()
}()

const concurrentQueries = 1000
wg.Add(concurrentQueries)
for i := 0; i < concurrentQueries; i++ {
go func() {
defer wg.Done()
rows, err := db.Query("SELECT count(*) from foo")
if err != nil {
t.Fatal("Unable to query foo table:", err)
t.Error("Unable to query foo table:", err)
return
}

if rows.Next() {
var count int
if err := rows.Scan(&count); err != nil {
t.Fatal("Failed to Scan rows", err)
t.Error("Failed to Scan rows", err)
return
}
if count != 1 {
t.Errorf("count=%d want=%d", count, 1)
}
}
if err := rows.Err(); err != nil {
t.Fatal("Failed at the call to Next:", err)
t.Error("Failed at the call to Next:", err)
return
}
wg.Done()
}()
}
wg.Wait()
Expand Down Expand Up @@ -177,16 +183,17 @@ func TestUnlockNotifyDeadlock(t *testing.T) {
wg.Add(1)
timer := time.NewTimer(500 * time.Millisecond)
go func() {
defer wg.Done()
<-timer.C
err := tx.Commit()
if err != nil {
t.Fatal("Failed to commit transaction:", err)
}
wg.Done()
}()

wg.Add(1)
go func() {
defer wg.Done()
tx2, err := db.Begin()
if err != nil {
t.Fatal("Failed to begin transaction:", err)
Expand All @@ -201,7 +208,6 @@ func TestUnlockNotifyDeadlock(t *testing.T) {
if err != nil {
t.Fatal("Failed to commit transaction:", err)
}
wg.Done()
}()

rows, err := tx.Query("SELECT count(*) from foo")
Expand Down
Loading

0 comments on commit d2d4030

Please sign in to comment.