diff --git a/executor/cte.go b/executor/cte.go index ff3121be362ab..49848e4e75ce1 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -233,7 +233,7 @@ func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { err = errors.Errorf("%v", r) } }() - failpoint.Inject("testCTEPanic", nil) + failpoint.Inject("testCTESeedPanic", nil) e.curIter = 0 e.iterInTbl.SetIter(e.curIter) chks := make([]*chunk.Chunk, 0, 10) @@ -272,6 +272,7 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { err = errors.Errorf("%v", r) } }() + failpoint.Inject("testCTERecursivePanic", nil) if e.recursiveExec == nil || e.iterInTbl.NumChunks() == 0 { return } diff --git a/executor/cte_test.go b/executor/cte_test.go index 0c703169570c3..a166fabee02c4 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -461,7 +461,7 @@ func TestCTEsInView(t *testing.T) { tk.MustQuery("select * from test.v;").Check(testkit.Rows("1")) } -func TestCTESeedPanic(t *testing.T) { +func TestCTEPanic(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) tk.MustExec("use test;") @@ -469,7 +469,15 @@ func TestCTESeedPanic(t *testing.T) { tk.MustExec("insert into t1 values(1), (2), (3)") fpPathPrefix := "github.com/pingcap/tidb/executor/" - require.NoError(t, failpoint.Enable(fpPathPrefix+"testCTEPanic", `panic("testCTEPanic")`)) + fp := "testCTESeedPanic" + require.NoError(t, failpoint.Enable(fpPathPrefix+fp, fmt.Sprintf(`panic("%s")`, fp))) err := tk.QueryToErr("with recursive cte1 as (select c1 from t1 union all select c1 + 1 from cte1 where c1 < 5) select t_alias_1.c1 from cte1 as t_alias_1 inner join cte1 as t_alias_2 on t_alias_1.c1 = t_alias_2.c1 order by c1") - require.Contains(t, err.Error(), "testCTEPanic") + require.Contains(t, err.Error(), fp) + require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) + + fp = "testCTERecursivePanic" + require.NoError(t, failpoint.Enable(fpPathPrefix+fp, fmt.Sprintf(`panic("%s")`, fp))) + err = tk.QueryToErr("with recursive cte1 as (select c1 from t1 union all select c1 + 1 from cte1 where c1 < 5) select t_alias_1.c1 from cte1 as t_alias_1 inner join cte1 as t_alias_2 on t_alias_1.c1 = t_alias_2.c1 order by c1") + require.Contains(t, err.Error(), fp) + require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) }