Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions hooks/base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package hooks

import "context"

type Base struct {
}

func (b *Base) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
}

func (b *Base) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
}
42 changes: 42 additions & 0 deletions hooks/safetyhooks/safetyhooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package safetyhooks

import (
"database/sql/driver"
"fmt"
"runtime"

"github.com/gchaincl/sqlhooks/v2/hooks"
)

type Hook struct {
hooks.Base
}

func New() *Hook {
return &Hook{}
}

// safeRows wrap a driver.Rows interface in order to implement Sharp-Edged
// Finalizers based on https://crawshaw.io/blog/sharp-edged-finalizers.
type safeRows struct {
driver.Rows
}

func (s *safeRows) Close() {
runtime.SetFinalizer(s, nil)
s.Rows.Close()
}

func doPanic() {
_, file, line, _ := runtime.Caller(1)
panic(fmt.Sprintf("%s:%d: row not closed", file, line))
}

func (h *Hook) Rows(r driver.Rows) driver.Rows {
s := &safeRows{r}
runtime.SetFinalizer(s, func(*safeRows) {
doPanic()
})

return r
}
41 changes: 41 additions & 0 deletions hooks/safetyhooks/safetyhooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package safetyhooks

import (
"database/sql"
"testing"

"github.com/gchaincl/sqlhooks/v2"
"github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)

func setupTestDB(t *testing.T, hooks sqlhooks.Hooks) *sql.DB {
var (
err error
name = "final"
)

sql.Register(name, sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks))
db, err := sql.Open(name, ":memory:")
require.NoError(t, err)

_, err = db.Exec("CREATE TABLE test(id int)")
require.NoError(t, err)

_, err = db.Exec("INSERT INTO test VALUES(1)")
require.NoError(t, err)

return db
}

func doQuery(db *sql.DB, query string) (*sql.Rows, error) {
return db.Query(query)
}

func TestFinalizers(t *testing.T) {
hooks := New()
db := setupTestDB(t, hooks)

_, err := doQuery(db, "SELECT * from test")
require.NoError(t, err)
}
10 changes: 10 additions & 0 deletions sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ type OnErrorer interface {
OnError(ctx context.Context, err error, query string, args ...interface{}) error
}

// RowsWrapper is an optional interface for Hooks representing the hability of
// wrapper rows.
type RowsWrapper interface {
Rows(r driver.Rows) driver.Rows
}

func handlerErr(ctx context.Context, hooks Hooks, err error, query string, args ...interface{}) error {
h, ok := hooks.(OnErrorer)
if !ok {
Expand Down Expand Up @@ -219,6 +225,10 @@ func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args
return nil, err
}

if w, ok := conn.hooks.(RowsWrapper); ok {
results = w.Rows(results)
}

return results, err
}

Expand Down