diff --git a/executor/cte.go b/executor/cte.go index 4ea7ae184fcc6..569d59298d9c6 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -231,6 +231,12 @@ func (e *CTEExec) Close() (err error) { } func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil && err == nil { + err = errors.Errorf("%v", r) + } + }() + failpoint.Inject("testCTESeedPanic", nil) e.curIter = 0 e.iterInTbl.SetIter(e.curIter) chks := make([]*chunk.Chunk, 0, 10) @@ -240,13 +246,13 @@ func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { } chk := tryNewCacheChunk(e.seedExec) if err = Next(ctx, e.seedExec, chk); err != nil { - return err + return } if chk.NumRows() == 0 { break } if chk, err = e.tryDedupAndAdd(chk, e.iterInTbl, e.hashTbl); err != nil { - return err + return } chks = append(chks, chk) } @@ -254,18 +260,24 @@ func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { // Just adding is ok. for _, chk := range chks { if err = e.resTbl.Add(chk); err != nil { - return err + return } } e.curIter++ e.iterInTbl.SetIter(e.curIter) - return nil + return } func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil && err == nil { + err = errors.Errorf("%v", r) + } + }() + failpoint.Inject("testCTERecursivePanic", nil) if e.recursiveExec == nil || e.iterInTbl.NumChunks() == 0 { - return nil + return } if e.curIter > e.ctx.GetSessionVars().CTEMaxRecursionDepth { @@ -273,17 +285,17 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { } if e.limitDone(e.resTbl) { - return nil + return } for { chk := tryNewCacheChunk(e.recursiveExec) if err = Next(ctx, e.recursiveExec, chk); err != nil { - return err + return } if chk.NumRows() == 0 { if err = e.setupTblsForNewIteration(); err != nil { - return err + return } if e.limitDone(e.resTbl) { break @@ -300,18 +312,18 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { // Make sure iterInTbl is setup before Close/Open, // because some executors will read iterInTbl in Open() (like IndexLookupJoin). if err = e.recursiveExec.Close(); err != nil { - return err + return } if err = e.recursiveExec.Open(ctx); err != nil { - return err + return } } else { if err = e.iterOutTbl.Add(chk); err != nil { - return err + return } } } - return nil + return } // Get next chunk from resTbl for limit. diff --git a/executor/cte_test.go b/executor/cte_test.go index 368d4bfd07796..9d4214dce9438 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -449,3 +449,24 @@ func TestCTEsInView(t *testing.T) { tk.MustExec("use test1;") tk.MustQuery("select * from test.v;").Check(testkit.Rows("1")) } + +func TestCTEPanic(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("create table t1(c1 int)") + tk.MustExec("insert into t1 values(1), (2), (3)") + + fpPathPrefix := "github.com/pingcap/tidb/executor/" + fp := "testCTESeedPanic" + require.NoError(t, failpoint.Enable(fpPathPrefix+fp, fmt.Sprintf(`panic("%s")`, fp))) + err := tk.QueryToErr("with recursive cte1 as (select c1 from t1 union all select c1 + 1 from cte1 where c1 < 5) select t_alias_1.c1 from cte1 as t_alias_1 inner join cte1 as t_alias_2 on t_alias_1.c1 = t_alias_2.c1 order by c1") + require.Contains(t, err.Error(), fp) + require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) + + fp = "testCTERecursivePanic" + require.NoError(t, failpoint.Enable(fpPathPrefix+fp, fmt.Sprintf(`panic("%s")`, fp))) + err = tk.QueryToErr("with recursive cte1 as (select c1 from t1 union all select c1 + 1 from cte1 where c1 < 5) select t_alias_1.c1 from cte1 as t_alias_1 inner join cte1 as t_alias_2 on t_alias_1.c1 = t_alias_2.c1 order by c1") + require.Contains(t, err.Error(), fp) + require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) +} diff --git a/util/cteutil/storage.go b/util/cteutil/storage.go index 02d82cef9e660..dea6fd632e42b 100644 --- a/util/cteutil/storage.go +++ b/util/cteutil/storage.go @@ -129,13 +129,14 @@ func (s *StorageRC) DerefAndClose() (err error) { if s.refCnt < 0 { return errors.New("Storage ref count is less than zero") } else if s.refCnt == 0 { - // TODO: unreg memtracker + s.refCnt = -1 + s.done = false + s.err = nil + s.iter = 0 if err = s.rc.Close(); err != nil { return err } - if err = s.resetAll(); err != nil { - return err - } + s.rc = nil } return nil } @@ -155,7 +156,7 @@ func (s *StorageRC) SwapData(other Storage) (err error) { // Reopen impls Storage Reopen interface. func (s *StorageRC) Reopen() (err error) { - if err = s.rc.Reset(); err != nil { + if err = s.rc.Close(); err != nil { return err } s.iter = 0 @@ -265,18 +266,6 @@ func (s *StorageRC) ActionSpillForTest() *chunk.SpillDiskAction { return s.rc.ActionSpillForTest() } -func (s *StorageRC) resetAll() error { - s.refCnt = -1 - s.done = false - s.err = nil - s.iter = 0 - if err := s.rc.Reset(); err != nil { - return err - } - s.rc = nil - return nil -} - func (s *StorageRC) valid() bool { return s.refCnt > 0 && s.rc != nil }