diff --git a/pkg/executor/executor_failpoint_test.go b/pkg/executor/executor_failpoint_test.go
index 01b4ff86e1469..238b2838fce3f 100644
--- a/pkg/executor/executor_failpoint_test.go
+++ b/pkg/executor/executor_failpoint_test.go
@@ -652,3 +652,26 @@ func TestGetMvccByEncodedKeyRegionError(t *testing.T) {
 	require.Equal(t, 1, len(resp.Info.Writes))
 	require.Equal(t, commitTs, resp.Info.Writes[0].CommitTs)
 }
+
+func TestShuffleExit(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t1;")
+	tk.MustExec("create table t1(i int, j int, k int);")
+	tk.MustExec("insert into t1 VALUES (1,1,1),(2,2,2),(3,3,3),(4,4,4);")
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/shuffleError", "return(true)"))
+	defer func() {
+		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/shuffleError"))
+	}()
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/shuffleExecFetchDataAndSplit", "return(true)"))
+	defer func() {
+		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/shuffleExecFetchDataAndSplit"))
+	}()
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/shuffleWorkerRun", "panic(\"ShufflePanic\")"))
+	defer func() {
+		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/shuffleWorkerRun"))
+	}()
+	err := tk.QueryToErr("SELECT SUM(i) OVER W FROM t1 WINDOW w AS (PARTITION BY j ORDER BY i) ORDER BY 1+SUM(i) OVER w;")
+	require.ErrorContains(t, err, "ShuffleExec.Next error")
+}
diff --git a/pkg/executor/shuffle.go b/pkg/executor/shuffle.go
index c1b9a4a0faecd..7027f5bdc7e26 100644
--- a/pkg/executor/shuffle.go
+++ b/pkg/executor/shuffle.go
@@ -17,6 +17,7 @@ package executor
 import (
 	"context"
 	"sync"
+	"time"
 
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
@@ -116,7 +117,7 @@ func (e *ShuffleExec) Open(ctx context.Context) error {
 
 	e.prepared = false
 	e.finishCh = make(chan struct{}, 1)
-	e.outputCh = make(chan *shuffleOutput, e.concurrency)
+	e.outputCh = make(chan *shuffleOutput, e.concurrency+len(e.dataSources))
 
 	for _, w := range e.workers {
 		w.finishCh = e.finishCh
@@ -202,13 +203,13 @@ func (e *ShuffleExec) Close() error {
 }
 
 func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) {
+	waitGroup := &sync.WaitGroup{}
+	waitGroup.Add(len(e.workers) + len(e.dataSources))
 	// create a goroutine for each dataSource to fetch and split data
 	for i := range e.dataSources {
-		go e.fetchDataAndSplit(ctx, i)
+		go e.fetchDataAndSplit(ctx, i, waitGroup)
 	}
 
-	waitGroup := &sync.WaitGroup{}
-	waitGroup.Add(len(e.workers))
 	for _, w := range e.workers {
 		go w.run(ctx, waitGroup)
 	}
@@ -259,7 +260,7 @@ func recoveryShuffleExec(output chan *shuffleOutput, r interface{}) {
 	logutil.BgLogger().Error("shuffle panicked", zap.Error(err), zap.Stack("stack"))
 }
 
-func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int) {
+func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int, waitGroup *sync.WaitGroup) {
 	var (
 		err           error
 		workerIndices []int
@@ -274,8 +275,16 @@ func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int
 		for _, w := range e.workers {
 			close(w.receivers[dataSourceIndex].inputCh)
 		}
+		waitGroup.Done()
 	}()
 
+	failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) {
+		if val.(bool) {
+			time.Sleep(100 * time.Millisecond)
+			panic("shuffleExecFetchDataAndSplitPanic")
+		}
+	})
+
	for {
		err = exec.Next(ctx, e.dataSources[dataSourceIndex], chk)
		if err != nil {
@@ -390,6 +399,7 @@ func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) {
 		waitGroup.Done()
 	}()
 
+	failpoint.Inject("shuffleWorkerRun", nil)
 	for {
 		select {
 		case <-e.finishCh:
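
Why the patch enlarges outputCh to e.concurrency+len(e.dataSources) and folds the data-source goroutines into the WaitGroup: every goroutine that can panic must be able to deposit its recovered error into the output channel without blocking, and the channel must not be closed until all of them have called Done(). A minimal standalone sketch of that pattern follows; it is not TiDB code, and the names producer and result are illustrative assumptions:

package main

import (
	"fmt"
	"sync"
)

type result struct {
	err error
}

func producer(id int, out chan<- result, wg *sync.WaitGroup) {
	defer wg.Done()
	// Turn a panic into an error value. Because the channel is buffered
	// with one slot per producer, this send can never block, even if the
	// consumer has already stopped reading.
	defer func() {
		if r := recover(); r != nil {
			out <- result{err: fmt.Errorf("producer %d panicked: %v", id, r)}
		}
	}()
	if id == 1 {
		panic("boom")
	}
}

func main() {
	const producers = 3
	// One buffer slot per goroutine that may report an error, analogous to
	// make(chan *shuffleOutput, e.concurrency+len(e.dataSources)) above.
	out := make(chan result, producers)

	var wg sync.WaitGroup
	wg.Add(producers)
	for i := 0; i < producers; i++ {
		go producer(i, out, &wg)
	}
	// Close the channel only after every producer has finished, matching
	// the patch's choice to count the data-source goroutines in the same
	// WaitGroup as the workers.
	go func() {
		wg.Wait()
		close(out)
	}()

	for r := range out {
		fmt.Println("got error:", r.err)
	}
}

Under the old code, a fetchDataAndSplit goroutine was outside the WaitGroup, so a panic there could race with channel shutdown or block on a full outputCh; the sketch shows why sizing the buffer to the total producer count plus unified Done() accounting removes both hazards.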