@@ -126,17 +126,19 @@ func getQueryProcess(db *sql.DB, dbName, query string) (*dbProcess, error) {
126126 return longProcess , err
127127}
128128
129- func killQuery (db * sql.DB , dbName , query string , cancel context.CancelFunc ) error {
129+ var expectedKilledErr = fmt .Errorf ("process expected to be killed" )
130+
131+ func killQuery (db * sql.DB , dbName , query string , timeout time.Duration , cancel context.CancelFunc ) error {
130132 process , err := getQueryProcess (db , dbName , query )
131133 if err != nil {
132134 return fmt .Errorf ("failed to get mysql process: %v" , err )
133135 }
134136 cancel ()
135137
136- end := time .Now ().Add (killTimeout )
138+ end := time .Now ().Add (timeout )
137139 for time .Now ().Before (end ) {
138140 if checkProcessExists (dbName , process .ID , db ) {
139- err = fmt . Errorf ( "process %d expected to be killed" , process . ID )
141+ err = expectedKilledErr
140142 time .Sleep (pollTimeout )
141143 } else {
142144 err = nil
@@ -173,7 +175,7 @@ func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, que
173175 }()
174176
175177 // it is safe to not use timeouts here since they are inside the killQuery function
176- err = killQuery (dbt .db , dbname , query , cancel )
178+ err = killQuery (dbt .db , dbname , query , killTimeout , cancel )
177179 if err != nil {
178180 dbt .Error (err )
179181 return
@@ -195,6 +197,62 @@ func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, que
195197 tx .Commit ()
196198}
197199
200+ func testCancelNoKill (dbt * DBTest , ctx context.Context , cancel context.CancelFunc , query string , queryFunc func () error ) {
201+ tx , err := dbt .db .BeginTx (context .Background (), nil )
202+ if err != nil {
203+ dbt .Fatal (err )
204+ return
205+ }
206+
207+ _ , err = tx .Exec ("LOCK TABLES test WRITE" )
208+ if err != nil {
209+ tx .Rollback ()
210+ dbt .Fatal (err )
211+ }
212+
213+ errChan := make (chan error )
214+ go func () {
215+ // This query will be canceled.
216+ err = queryFunc ()
217+ if err != nil && err != context .Canceled {
218+ errLog .Print (err )
219+ }
220+ if err != context .Canceled && ctx .Err () != context .Canceled {
221+ errChan <- fmt .Errorf ("expected context.Canceled, got %v" , err )
222+ return
223+ }
224+ errChan <- nil
225+ }()
226+
227+ // it is safe to not use timeouts here since they are inside the killQuery function
228+ err = killQuery (dbt .db , dbname , query , 500 * time .Millisecond , cancel )
229+ if err != expectedKilledErr {
230+ if err == nil {
231+ dbt .Errorf ("query kill expected to fail" )
232+ } else {
233+ dbt .Errorf (fmt .Sprintf ("unexpected error %s" , err ))
234+ }
235+ }
236+
237+ _ , err = tx .Exec ("UNLOCK TABLES" )
238+ if err != nil {
239+ tx .Rollback ()
240+ dbt .Fatal (err )
241+ }
242+ tx .Commit ()
243+
244+ <- errChan
245+ }
246+
247+ func getKillDSN () string {
248+ cfg , err := ParseDSN (dsn )
249+ if err != nil {
250+ panic (err )
251+ }
252+ cfg .KillQueryOnTimeout = true
253+ return cfg .FormatDSN ()
254+ }
255+
198256func TestMultiResultSet (t * testing.T ) {
199257 type result struct {
200258 values [][]int
@@ -385,12 +443,62 @@ func TestPingContext(t *testing.T) {
385443 })
386444}
387445
388- func TestContextCancelExec (t * testing.T ) {
446+ func TestContextCancelNoKill (t * testing.T ) {
389447 runTests (t , dsn , func (dbt * DBTest ) {
390448 dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
391449 ctx , cancel := context .WithCancel (context .Background ())
392450 exec := "INSERT INTO test VALUES(1)"
393451
452+ testCancelNoKill (dbt , ctx , cancel , exec , func () error {
453+ _ , err := dbt .db .ExecContext (ctx , exec )
454+ return err
455+ })
456+
457+ // Check how many times the query is executed.
458+ var v int
459+ var err error
460+ for i := 0 ; i != 3 ; i ++ {
461+ err = nil
462+ if err := dbt .db .QueryRow ("SELECT COUNT(*) FROM test" ).Scan (& v ); err != nil {
463+ dbt .Fatalf ("%s" , err .Error ())
464+ return
465+ }
466+ if v != 1 {
467+ err = fmt .Errorf ("expected val to be 1, got %d" , v )
468+ }
469+
470+ if err != nil {
471+ time .Sleep (100 * time .Millisecond ) // wait while insert is executed after table lock released
472+ }
473+ }
474+ if err != nil {
475+ dbt .Error (err )
476+ return
477+ }
478+
479+ // Context is already canceled, so error should come before execution.
480+ if _ , err := dbt .db .ExecContext (ctx , "INSERT INTO test VALUES (1)" ); err == nil {
481+ dbt .Error ("expected error" )
482+ } else if err .Error () != "context canceled" {
483+ dbt .Fatalf ("unexpected error: %s" , err )
484+ }
485+
486+ // The second insert query will fail, so the table has no changes.
487+ if err := dbt .db .QueryRow ("SELECT COUNT(*) FROM test" ).Scan (& v ); err != nil {
488+ dbt .Fatalf ("%s" , err .Error ())
489+ }
490+ if v != 1 {
491+ dbt .Errorf ("expected val to be 1, got %d" , v )
492+ }
493+ })
494+ }
495+
496+ func TestContextCancelExec (t * testing.T ) {
497+ runTests (t , getKillDSN (), func (dbt * DBTest ) {
498+ dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
499+ ctx , cancel := context .WithCancel (context .Background ())
500+ exec := "INSERT INTO test VALUES(1)"
501+
394502 testCancel (dbt , ctx , cancel , exec , func () error {
395503 _ , err := dbt .db .ExecContext (ctx , exec )
396504 return err
@@ -423,7 +531,7 @@ func TestContextCancelExec(t *testing.T) {
423531}
424532
425533func TestContextCancelQuery (t * testing.T ) {
426- runTests (t , dsn , func (dbt * DBTest ) {
534+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
427535 dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
428536 ctx , cancel := context .WithCancel (context .Background ())
429537 query := "SELECT 1 FROM test"
@@ -501,7 +609,7 @@ func TestContextCancelPrepare(t *testing.T) {
501609}
502610
503611func TestContextCancelStmtExec (t * testing.T ) {
504- runTests (t , dsn , func (dbt * DBTest ) {
612+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
505613 dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
506614 ctx , cancel := context .WithCancel (context .Background ())
507615 exec := "INSERT INTO test VALUES(1)"
@@ -528,7 +636,7 @@ func TestContextCancelStmtExec(t *testing.T) {
528636}
529637
530638func TestContextCancelStmtQuery (t * testing.T ) {
531- runTests (t , dsn , func (dbt * DBTest ) {
639+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
532640 dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
533641 ctx , cancel := context .WithCancel (context .Background ())
534642 query := "SELECT 1 FROM test"
@@ -555,7 +663,7 @@ func TestContextCancelStmtQuery(t *testing.T) {
555663}
556664
557665func TestContextCancelBegin (t * testing.T ) {
558- runTests (t , dsn , func (dbt * DBTest ) {
666+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
559667 dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
560668 ctx , cancel := context .WithCancel (context .Background ())
561669 query := "SELECT 1 FROM test"
0 commit comments