Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
with:
fetch-depth: 1

- uses: actions/cache@v2
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
Expand Down
127 changes: 80 additions & 47 deletions sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ type Conn struct {
}

func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return conn.prepareContext(ctx, query)
}

func (conn *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
var (
stmt driver.Stmt
err error
Expand All @@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
}

if err != nil {
return stmt, err
return nil, err
}

return &Stmt{stmt, conn.hooks, query}, nil
Expand Down Expand Up @@ -139,21 +143,39 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [
}

func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return execWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Result, error) {
results, err := conn.execContext(ctx, query, args)
if err == nil || !errors.Is(err, driver.ErrSkip) {
return results, err
}
// If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
// We need to avoid executing the hooks twice since they were already run in ExecContext.
// This matches the behavior in database/sql when ExecContext returns ErrSkip.
stmt, err := conn.prepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.execContext(ctx, args)
})
}

func execWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, execer func(context.Context) (driver.Result, error)) (driver.Result, error) {
var err error

list := namedToInterface(args)

// Exec `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
if ctx, err = hooks.Before(ctx, query, list...); err != nil {
return nil, err
}

results, err := conn.execContext(ctx, query, args)
results, err := execer(ctx)
if err != nil {
return results, handlerErr(ctx, conn.hooks, err, query, list...)
return results, handlerErr(ctx, hooks, err, query, list...)
}

if _, err := conn.hooks.After(ctx, query, list...); err != nil {
if _, err := hooks.After(ctx, query, list...); err != nil {
return nil, err
}

Expand Down Expand Up @@ -201,21 +223,43 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args
}

func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return queryWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Rows, error) {
rows, err := conn.queryContext(ctx, query, args)
if err == nil || !errors.Is(err, driver.ErrSkip) {
return rows, err
}
// If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
// We need to avoid executing the hooks twice since they were already run in QueryContext.
// This matches the behavior in database/sql when QueryContext returns ErrSkip.
stmt, err := conn.prepareContext(ctx, query)
if err != nil {
return nil, err
}
rows, err = stmt.queryContext(ctx, args)
if err != nil {
_ = stmt.Close()
return nil, err
}
return &rowsWrapper{rows: rows, closeStmt: stmt}, nil
})
}

func queryWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, queryer func(context.Context) (driver.Rows, error)) (driver.Rows, error) {
var err error

list := namedToInterface(args)

// Query `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
if ctx, err = hooks.Before(ctx, query, list...); err != nil {
return nil, err
}

results, err := conn.queryContext(ctx, query, args)
results, err := queryer(ctx)
if err != nil {
return results, handlerErr(ctx, conn.hooks, err, query, list...)
return results, handlerErr(ctx, hooks, err, query, list...)
}

if _, err := conn.hooks.After(ctx, query, list...); err != nil {
if _, err := hooks.After(ctx, query, list...); err != nil {
return nil, err
}

Expand Down Expand Up @@ -264,25 +308,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr
}

func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
var err error

list := namedToInterface(args)

// Exec `Before` Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}

results, err := stmt.execContext(ctx, args)
if err != nil {
return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
}

if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}

return results, err
return execWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Result, error) {
return stmt.execContext(ctx, args)
})
}

func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
Expand All @@ -298,25 +326,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d
}

func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
var err error

list := namedToInterface(args)

// Exec Before Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}

rows, err := stmt.queryContext(ctx, args)
if err != nil {
return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
}

if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}

return rows, err
return queryWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Rows, error) {
return stmt.queryContext(ctx, args)
})
}

func (stmt *Stmt) Close() error { return stmt.Stmt.Close() }
Expand Down Expand Up @@ -350,6 +362,27 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
return dargs, nil
}

type rowsWrapper struct {
rows driver.Rows
closeStmt driver.Stmt // if non-nil, statement to Close on close
}

func (r *rowsWrapper) Close() error {
err := r.rows.Close()
if r.closeStmt != nil {
_ = r.closeStmt.Close()
}
return err
}

func (r *rowsWrapper) Columns() []string {
return r.rows.Columns()
}

func (r *rowsWrapper) Next(dest []driver.Value) error {
return r.rows.Next(dest)
}

/*
type hooks struct {
}
Expand Down
40 changes: 20 additions & 20 deletions sqlhooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,62 +68,62 @@ func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
}

func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) {
var before, after bool
var beforeCount, afterCount int

s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
before = true
beforeCount++
return ctx, nil
}
s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
after = true
afterCount++
return ctx, nil
}

t.Run("Query", func(t *testing.T) {
before, after = false, false
beforeCount, afterCount = 0, 0
_, err := s.db.Query(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
})

t.Run("QueryContext", func(t *testing.T) {
before, after = false, false
beforeCount, afterCount = 0, 0
_, err := s.db.QueryContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
})

t.Run("Exec", func(t *testing.T) {
before, after = false, false
beforeCount, afterCount = 0, 0
_, err := s.db.Exec(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
})

t.Run("ExecContext", func(t *testing.T) {
before, after = false, false
beforeCount, afterCount = 0, 0
_, err := s.db.ExecContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
})

t.Run("Statements", func(t *testing.T) {
before, after = false, false
beforeCount, afterCount = 0, 0
stmt, err := s.db.Prepare(query)
require.NoError(t, err)

// Hooks just run when the stmt is executed (Query or Exec)
assert.False(t, before, "Before Hook run before execution: "+query)
assert.False(t, after, "After Hook run before execution: "+query)
assert.Equal(t, 0, beforeCount, "Before Hook run before execution: "+query)
assert.Equal(t, 0, afterCount, "After Hook run before execution: "+query)

_, err = stmt.Query(args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
})
}

Expand Down