@@ -81,6 +81,10 @@ type Conn struct {
8181}
8282
8383func (conn * Conn ) PrepareContext (ctx context.Context , query string ) (driver.Stmt , error ) {
84+ return conn .prepareContext (ctx , query )
85+ }
86+
87+ func (conn * Conn ) prepareContext (ctx context.Context , query string ) (* Stmt , error ) {
8488 var (
8589 stmt driver.Stmt
8690 err error
@@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
9397 }
9498
9599 if err != nil {
96- return stmt , err
100+ return nil , err
97101 }
98102
99103 return & Stmt {stmt , conn .hooks , query }, nil
@@ -139,21 +143,38 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [
139143}
140144
141145func (conn * ExecerContext ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
146+ return execWithHooks (ctx , query , args , conn .hooks , func (ctx context.Context ) (driver.Result , error ) {
147+ results , err := conn .execContext (ctx , query , args )
148+ if err == nil || ! errors .Is (err , driver .ErrSkip ) {
149+ return results , err
150+ }
151+ // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
152+ // We need to avoid executing the hooks twice since they were already run in ExecContext.
153+ // This matches the behavior in database/sql when ExecContext returns ErrSkip.
154+ stmt , err := conn .prepareContext (ctx , query )
155+ if err != nil {
156+ return nil , err
157+ }
158+ return stmt .execContext (ctx , args )
159+ })
160+ }
161+
162+ func execWithHooks (ctx context.Context , query string , args []driver.NamedValue , hooks Hooks , execer func (context.Context ) (driver.Result , error )) (driver.Result , error ) {
142163 var err error
143164
144165 list := namedToInterface (args )
145166
146167 // Exec `Before` Hooks
147- if ctx , err = conn . hooks .Before (ctx , query , list ... ); err != nil {
168+ if ctx , err = hooks .Before (ctx , query , list ... ); err != nil {
148169 return nil , err
149170 }
150171
151- results , err := conn . execContext (ctx , query , args )
172+ results , err := execer (ctx )
152173 if err != nil {
153- return results , handlerErr (ctx , conn . hooks , err , query , list ... )
174+ return results , handlerErr (ctx , hooks , err , query , list ... )
154175 }
155176
156- if _ , err := conn . hooks .After (ctx , query , list ... ); err != nil {
177+ if _ , err := hooks .After (ctx , query , list ... ); err != nil {
157178 return nil , err
158179 }
159180
@@ -201,21 +222,38 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args
201222}
202223
203224func (conn * QueryerContext ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
225+ return queryWithHooks (ctx , query , args , conn .hooks , func (ctx context.Context ) (driver.Rows , error ) {
226+ results , err := conn .queryContext (ctx , query , args )
227+ if err == nil || ! errors .Is (err , driver .ErrSkip ) {
228+ return results , err
229+ }
230+ // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query.
231+ // We need to avoid executing the hooks twice since they were already run in QueryContext.
232+ // This matches the behavior in database/sql when QueryContext returns ErrSkip.
233+ stmt , err := conn .prepareContext (ctx , query )
234+ if err != nil {
235+ return nil , err
236+ }
237+ return stmt .queryContext (ctx , args )
238+ })
239+ }
240+
241+ func queryWithHooks (ctx context.Context , query string , args []driver.NamedValue , hooks Hooks , queryer func (context.Context ) (driver.Rows , error )) (driver.Rows , error ) {
204242 var err error
205243
206244 list := namedToInterface (args )
207245
208246 // Query `Before` Hooks
209- if ctx , err = conn . hooks .Before (ctx , query , list ... ); err != nil {
247+ if ctx , err = hooks .Before (ctx , query , list ... ); err != nil {
210248 return nil , err
211249 }
212250
213- results , err := conn . queryContext (ctx , query , args )
251+ results , err := queryer (ctx )
214252 if err != nil {
215- return results , handlerErr (ctx , conn . hooks , err , query , list ... )
253+ return results , handlerErr (ctx , hooks , err , query , list ... )
216254 }
217255
218- if _ , err := conn . hooks .After (ctx , query , list ... ); err != nil {
256+ if _ , err := hooks .After (ctx , query , list ... ); err != nil {
219257 return nil , err
220258 }
221259
@@ -264,25 +302,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr
264302}
265303
266304func (stmt * Stmt ) ExecContext (ctx context.Context , args []driver.NamedValue ) (driver.Result , error ) {
267- var err error
268-
269- list := namedToInterface (args )
270-
271- // Exec `Before` Hooks
272- if ctx , err = stmt .hooks .Before (ctx , stmt .query , list ... ); err != nil {
273- return nil , err
274- }
275-
276- results , err := stmt .execContext (ctx , args )
277- if err != nil {
278- return results , handlerErr (ctx , stmt .hooks , err , stmt .query , list ... )
279- }
280-
281- if _ , err := stmt .hooks .After (ctx , stmt .query , list ... ); err != nil {
282- return nil , err
283- }
284-
285- return results , err
305+ return execWithHooks (ctx , stmt .query , args , stmt .hooks , func (ctx context.Context ) (driver.Result , error ) {
306+ return stmt .execContext (ctx , args )
307+ })
286308}
287309
288310func (stmt * Stmt ) queryContext (ctx context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
@@ -298,25 +320,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d
298320}
299321
300322func (stmt * Stmt ) QueryContext (ctx context.Context , args []driver.NamedValue ) (driver.Rows , error ) {
301- var err error
302-
303- list := namedToInterface (args )
304-
305- // Exec Before Hooks
306- if ctx , err = stmt .hooks .Before (ctx , stmt .query , list ... ); err != nil {
307- return nil , err
308- }
309-
310- rows , err := stmt .queryContext (ctx , args )
311- if err != nil {
312- return rows , handlerErr (ctx , stmt .hooks , err , stmt .query , list ... )
313- }
314-
315- if _ , err := stmt .hooks .After (ctx , stmt .query , list ... ); err != nil {
316- return nil , err
317- }
318-
319- return rows , err
323+ return queryWithHooks (ctx , stmt .query , args , stmt .hooks , func (ctx context.Context ) (driver.Rows , error ) {
324+ return stmt .queryContext (ctx , args )
325+ })
320326}
321327
322328func (stmt * Stmt ) Close () error { return stmt .Stmt .Close () }
0 commit comments