Skip to content

Commit eb4e59f

Browse files
committed
fix: handle driver.ErrSkip to avoid duplicate hooks execution with MySQL driver
When InterpolateParams=false is set in MySQL driver, it returns driver.ErrSkip which causes the SQL package to fall back to prepared statements, resulting in hooks being executed twice. This change handles driver.ErrSkip internally to ensure hooks are only executed once per logical operation.
1 parent 7875602 commit eb4e59f

File tree

2 files changed

+73
-67
lines changed

2 files changed

+73
-67
lines changed

sqlhooks.go

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ type Conn struct {
8181
}
8282

8383
func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
84+
return conn.prepareContext(ctx, query)
85+
}
86+
87+
func (conn *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
8488
var (
8589
stmt driver.Stmt
8690
err error
@@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
9397
}
9498

9599
if err != nil {
96-
return stmt, err
100+
return nil, err
97101
}
98102

99103
return &Stmt{stmt, conn.hooks, query}, nil
@@ -139,21 +143,38 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [
139143
}
140144

141145
func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
146+
return execWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Result, error) {
147+
results, err := conn.execContext(ctx, query, args)
148+
if err == nil || !errors.Is(err, driver.ErrSkip) {
149+
return results, err
150+
}
151+
// If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
152+
// We need to avoid executing the hooks twice since they were already run in ExecContext.
153+
// This matches the behavior in database/sql when ExecContext returns ErrSkip.
154+
stmt, err := conn.prepareContext(ctx, query)
155+
if err != nil {
156+
return nil, err
157+
}
158+
return stmt.execContext(ctx, args)
159+
})
160+
}
161+
162+
func execWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, execer func(context.Context) (driver.Result, error)) (driver.Result, error) {
142163
var err error
143164

144165
list := namedToInterface(args)
145166

146167
// Exec `Before` Hooks
147-
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
168+
if ctx, err = hooks.Before(ctx, query, list...); err != nil {
148169
return nil, err
149170
}
150171

151-
results, err := conn.execContext(ctx, query, args)
172+
results, err := execer(ctx)
152173
if err != nil {
153-
return results, handlerErr(ctx, conn.hooks, err, query, list...)
174+
return results, handlerErr(ctx, hooks, err, query, list...)
154175
}
155176

156-
if _, err := conn.hooks.After(ctx, query, list...); err != nil {
177+
if _, err := hooks.After(ctx, query, list...); err != nil {
157178
return nil, err
158179
}
159180

@@ -201,21 +222,38 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args
201222
}
202223

203224
func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
225+
return queryWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Rows, error) {
226+
results, err := conn.queryContext(ctx, query, args)
227+
if err == nil || !errors.Is(err, driver.ErrSkip) {
228+
return results, err
229+
}
230+
// If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
231+
// We need to avoid executing the hooks twice since they were already run in QueryContext.
232+
// This matches the behavior in database/sql when QueryContext returns ErrSkip.
233+
stmt, err := conn.prepareContext(ctx, query)
234+
if err != nil {
235+
return nil, err
236+
}
237+
return stmt.queryContext(ctx, args)
238+
})
239+
}
240+
241+
func queryWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, queryer func(context.Context) (driver.Rows, error)) (driver.Rows, error) {
204242
var err error
205243

206244
list := namedToInterface(args)
207245

208246
// Query `Before` Hooks
209-
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
247+
if ctx, err = hooks.Before(ctx, query, list...); err != nil {
210248
return nil, err
211249
}
212250

213-
results, err := conn.queryContext(ctx, query, args)
251+
results, err := queryer(ctx)
214252
if err != nil {
215-
return results, handlerErr(ctx, conn.hooks, err, query, list...)
253+
return results, handlerErr(ctx, hooks, err, query, list...)
216254
}
217255

218-
if _, err := conn.hooks.After(ctx, query, list...); err != nil {
256+
if _, err := hooks.After(ctx, query, list...); err != nil {
219257
return nil, err
220258
}
221259

@@ -264,25 +302,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr
264302
}
265303

266304
func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
267-
var err error
268-
269-
list := namedToInterface(args)
270-
271-
// Exec `Before` Hooks
272-
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
273-
return nil, err
274-
}
275-
276-
results, err := stmt.execContext(ctx, args)
277-
if err != nil {
278-
return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
279-
}
280-
281-
if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil {
282-
return nil, err
283-
}
284-
285-
return results, err
305+
return execWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Result, error) {
306+
return stmt.execContext(ctx, args)
307+
})
286308
}
287309

288310
func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
@@ -298,25 +320,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d
298320
}
299321

300322
func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
301-
var err error
302-
303-
list := namedToInterface(args)
304-
305-
// Exec Before Hooks
306-
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
307-
return nil, err
308-
}
309-
310-
rows, err := stmt.queryContext(ctx, args)
311-
if err != nil {
312-
return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
313-
}
314-
315-
if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil {
316-
return nil, err
317-
}
318-
319-
return rows, err
323+
return queryWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Rows, error) {
324+
return stmt.queryContext(ctx, args)
325+
})
320326
}
321327

322328
func (stmt *Stmt) Close() error { return stmt.Stmt.Close() }

sqlhooks_test.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,62 +68,62 @@ func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
6868
}
6969

7070
func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) {
71-
var before, after bool
71+
var beforeCount, afterCount int
7272

7373
s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
74-
before = true
74+
beforeCount++
7575
return ctx, nil
7676
}
7777
s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
78-
after = true
78+
afterCount++
7979
return ctx, nil
8080
}
8181

8282
t.Run("Query", func(t *testing.T) {
83-
before, after = false, false
83+
beforeCount, afterCount = 0, 0
8484
_, err := s.db.Query(query, args...)
8585
require.NoError(t, err)
86-
assert.True(t, before, "Before Hook did not run for query: "+query)
87-
assert.True(t, after, "After Hook did not run for query: "+query)
86+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
87+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
8888
})
8989

9090
t.Run("QueryContext", func(t *testing.T) {
91-
before, after = false, false
91+
beforeCount, afterCount = 0, 0
9292
_, err := s.db.QueryContext(context.Background(), query, args...)
9393
require.NoError(t, err)
94-
assert.True(t, before, "Before Hook did not run for query: "+query)
95-
assert.True(t, after, "After Hook did not run for query: "+query)
94+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
95+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
9696
})
9797

9898
t.Run("Exec", func(t *testing.T) {
99-
before, after = false, false
99+
beforeCount, afterCount = 0, 0
100100
_, err := s.db.Exec(query, args...)
101101
require.NoError(t, err)
102-
assert.True(t, before, "Before Hook did not run for query: "+query)
103-
assert.True(t, after, "After Hook did not run for query: "+query)
102+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
103+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
104104
})
105105

106106
t.Run("ExecContext", func(t *testing.T) {
107-
before, after = false, false
107+
beforeCount, afterCount = 0, 0
108108
_, err := s.db.ExecContext(context.Background(), query, args...)
109109
require.NoError(t, err)
110-
assert.True(t, before, "Before Hook did not run for query: "+query)
111-
assert.True(t, after, "After Hook did not run for query: "+query)
110+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
111+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
112112
})
113113

114114
t.Run("Statements", func(t *testing.T) {
115-
before, after = false, false
115+
beforeCount, afterCount = 0, 0
116116
stmt, err := s.db.Prepare(query)
117117
require.NoError(t, err)
118118

119119
// Hooks just run when the stmt is executed (Query or Exec)
120-
assert.False(t, before, "Before Hook run before execution: "+query)
121-
assert.False(t, after, "After Hook run before execution: "+query)
120+
assert.Equal(t, 0, beforeCount, "Before Hook run before execution: "+query)
121+
assert.Equal(t, 0, afterCount, "After Hook run before execution: "+query)
122122

123123
_, err = stmt.Query(args...)
124124
require.NoError(t, err)
125-
assert.True(t, before, "Before Hook did not run for query: "+query)
126-
assert.True(t, after, "After Hook did not run for query: "+query)
125+
assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query)
126+
assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query)
127127
})
128128
}
129129

0 commit comments

Comments
 (0)