From 638272efb2aa3be419a0dd56cb546121a9b5ff7a Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Sat, 27 Mar 2021 22:03:23 +0800 Subject: [PATCH] planner: fix bug that mpp avg will throw column index out of bound error (#23604) (#23605) --- planner/core/integration_test.go | 42 +++++++++++++++++++ planner/core/task.go | 8 ++-- .../testdata/integration_serial_suite_in.json | 6 +++ .../integration_serial_suite_out.json | 20 +++++++++ 4 files changed, 72 insertions(+), 4 deletions(-) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 14c7334202060..53f22e1694a92 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -713,6 +713,48 @@ func (s *testIntegrationSerialSuite) TestMPPWithBroadcastExchangeUnderNewCollati } } +func (s *testIntegrationSerialSuite) TestMPPAvgRewrite(c *C) { + defer collate.SetNewCollationEnabledForTest(false) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists table_1") + tk.MustExec("create table table_1(id int not null, value decimal(10,2))") + tk.MustExec("insert into table_1 values(1,1),(2,2)") + tk.MustExec("analyze table table_1") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "table_1" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + collate.SetNewCollationEnabledForTest(true) + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") + tk.MustExec("set @@session.tidb_allow_mpp = 1") + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } +} + func (s *testIntegrationSerialSuite) TestAggPushDownEngine(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/task.go b/planner/core/task.go index 24ed01ead04e6..4944cd7ce6fac 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1480,9 +1480,9 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { for i, aggFunc := range p.AggFuncs { if aggFunc.Name == ast.AggFuncAvg { // inset a count(column) - avgCount := *aggFunc + avgCount := aggFunc.Clone() avgCount.Name = ast.AggFuncCount - newAggFuncs = append(newAggFuncs, &avgCount) + newAggFuncs = append(newAggFuncs, avgCount) avgCount.RetTp = ft avgCountCol := &expression.Column{ UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), @@ -1490,9 +1490,9 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { } newSchema.Append(avgCountCol) // insert a sum(column) - avgSum := *aggFunc + avgSum := aggFunc.Clone() avgSum.Name = ast.AggFuncSum - newAggFuncs = append(newAggFuncs, &avgSum) + newAggFuncs = append(newAggFuncs, avgSum) newSchema.Append(p.schema.Columns[i]) avgSumCol := p.schema.Columns[i] // avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end) diff --git a/planner/core/testdata/integration_serial_suite_in.json b/planner/core/testdata/integration_serial_suite_in.json index e1b912b6c3d63..9aa55d875c1e4 100644 --- a/planner/core/testdata/integration_serial_suite_in.json +++ b/planner/core/testdata/integration_serial_suite_in.json @@ -93,6 +93,12 @@ "explain format = 'brief' select /*+ broadcast_join(a,b) */ * from table_1 a, table_1 b where a.value = b.value" ] }, + { + "name": "TestMPPAvgRewrite", + "cases": [ + "explain format = 'brief' select /*+ avg_to_cop() */ id, avg(value+1),avg(value) from table_1 group by id" + ] + }, { "name": "TestReadFromStorageHint", "cases": [ diff --git a/planner/core/testdata/integration_serial_suite_out.json b/planner/core/testdata/integration_serial_suite_out.json index b214f24368532..ddc7ec381bb23 100644 --- a/planner/core/testdata/integration_serial_suite_out.json +++ b/planner/core/testdata/integration_serial_suite_out.json @@ -909,6 +909,26 @@ } ] }, + { + "Name": "TestMPPAvgRewrite", + "Cases": [ + { + "SQL": "explain format = 'brief' select /*+ avg_to_cop() */ id, avg(value+1),avg(value) from table_1 group by id", + "Plan": [ + "Projection 2.00 root test.table_1.id, Column#4, Column#5", + "└─TableReader 2.00 root data:ExchangeSender", + " └─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: PassThrough", + " └─Projection 2.00 batchCop[tiflash] div(Column#4, cast(case(eq(Column#13, 0), 1, Column#13), decimal(20,0) BINARY))->Column#4, div(Column#5, cast(case(eq(Column#14, 0), 1, Column#14), decimal(20,0) BINARY))->Column#5, test.table_1.id", + " └─HashAgg 2.00 batchCop[tiflash] group by:test.table_1.id, funcs:sum(Column#15)->Column#13, funcs:sum(Column#16)->Column#4, funcs:sum(Column#17)->Column#14, funcs:sum(Column#18)->Column#5, funcs:firstrow(test.table_1.id)->test.table_1.id", + " └─ExchangeReceiver 2.00 batchCop[tiflash] ", + " └─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.table_1.id", + " └─HashAgg 2.00 batchCop[tiflash] group by:Column#29, funcs:count(Column#25)->Column#15, funcs:sum(Column#26)->Column#16, funcs:count(Column#27)->Column#17, funcs:sum(Column#28)->Column#18", + " └─Projection 2.00 batchCop[tiflash] plus(test.table_1.value, 1)->Column#25, plus(test.table_1.value, 1)->Column#26, test.table_1.value, test.table_1.value, test.table_1.id", + " └─TableFullScan 2.00 batchCop[tiflash] table:table_1 keep order:false" + ] + } + ] + }, { "Name": "TestReadFromStorageHint", "Cases": [