From a63dfc139d0ba8092c2c02e0bd1d61dc5b7cc835 Mon Sep 17 00:00:00 2001 From: Shenghui Wu <793703860@qq.com> Date: Mon, 4 Dec 2023 18:40:22 +0800 Subject: [PATCH] executor: fix tidb crash when shuffleExec quit unexpectedly (#48828) close pingcap/tidb#48230 --- pkg/executor/executor_failpoint_test.go | 23 +++++++++++++++++++++++ pkg/executor/shuffle.go | 20 +++++++++++++++----- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/pkg/executor/executor_failpoint_test.go b/pkg/executor/executor_failpoint_test.go index 15d52ccec3dda..c3537e1b639ae 100644 --- a/pkg/executor/executor_failpoint_test.go +++ b/pkg/executor/executor_failpoint_test.go @@ -653,3 +653,26 @@ func TestGetMvccByEncodedKeyRegionError(t *testing.T) { require.Equal(t, 1, len(resp.Info.Writes)) require.Equal(t, commitTs, resp.Info.Writes[0].CommitTs) } + +func TestShuffleExit(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(i int, j int, k int);") + tk.MustExec("insert into t1 VALUES (1,1,1),(2,2,2),(3,3,3),(4,4,4);") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/shuffleError", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/shuffleError")) + }() + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/shuffleExecFetchDataAndSplit", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/shuffleExecFetchDataAndSplit")) + }() + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/shuffleWorkerRun", "panic(\"ShufflePanic\")")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/shuffleWorkerRun")) + }() + err := tk.QueryToErr("SELECT SUM(i) OVER W FROM t1 WINDOW w AS (PARTITION BY j ORDER BY i) ORDER BY 1+SUM(i) OVER w;") + require.ErrorContains(t, err, "ShuffleExec.Next error") +} diff --git a/pkg/executor/shuffle.go b/pkg/executor/shuffle.go index 3e1188868313e..c7309cd62133f 100644 --- a/pkg/executor/shuffle.go +++ b/pkg/executor/shuffle.go @@ -17,6 +17,7 @@ package executor import ( "context" "sync" + "time" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -117,7 +118,7 @@ func (e *ShuffleExec) Open(ctx context.Context) error { e.prepared = false e.finishCh = make(chan struct{}, 1) - e.outputCh = make(chan *shuffleOutput, e.concurrency) + e.outputCh = make(chan *shuffleOutput, e.concurrency+len(e.dataSources)) for _, w := range e.workers { w.finishCh = e.finishCh @@ -203,13 +204,13 @@ func (e *ShuffleExec) Close() error { } func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) { + waitGroup := &sync.WaitGroup{} + waitGroup.Add(len(e.workers) + len(e.dataSources)) // create a goroutine for each dataSource to fetch and split data for i := range e.dataSources { - go e.fetchDataAndSplit(ctx, i) + go e.fetchDataAndSplit(ctx, i, waitGroup) } - waitGroup := &sync.WaitGroup{} - waitGroup.Add(len(e.workers)) for _, w := range e.workers { go w.run(ctx, waitGroup) } @@ -260,7 +261,7 @@ func recoveryShuffleExec(output chan *shuffleOutput, r interface{}) { logutil.BgLogger().Error("shuffle panicked", zap.Error(err), zap.Stack("stack")) } -func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int) { +func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int, waitGroup *sync.WaitGroup) { var ( err error workerIndices []int @@ -275,8 +276,16 @@ func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int for _, w := range e.workers { close(w.receivers[dataSourceIndex].inputCh) } + waitGroup.Done() }() + failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(100 * time.Millisecond) + panic("shuffleExecFetchDataAndSplitPanic") + } + }) + for { err = exec.Next(ctx, e.dataSources[dataSourceIndex], chk) if err != nil { @@ -391,6 +400,7 @@ func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) { waitGroup.Done() }() + failpoint.Inject("shuffleWorkerRun", nil) for { select { case <-e.finishCh: