Skip to content

Commit

Permalink
planner: fix bug that mpp avg will throw column index out of bound er…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Mar 27, 2021
1 parent 7e3d70d commit 638272e
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 4 deletions.
42 changes: 42 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,48 @@ func (s *testIntegrationSerialSuite) TestMPPWithBroadcastExchangeUnderNewCollati
}
}

func (s *testIntegrationSerialSuite) TestMPPAvgRewrite(c *C) {
defer collate.SetNewCollationEnabledForTest(false)
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists table_1")
tk.MustExec("create table table_1(id int not null, value decimal(10,2))")
tk.MustExec("insert into table_1 values(1,1),(2,2)")
tk.MustExec("analyze table table_1")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "table_1" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
}
}
}

collate.SetNewCollationEnabledForTest(true)
tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'")
tk.MustExec("set @@session.tidb_allow_mpp = 1")
var input []string
var output []struct {
SQL string
Plan []string
}
s.testData.GetTestCases(c, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
}
}

func (s *testIntegrationSerialSuite) TestAggPushDownEngine(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
8 changes: 4 additions & 4 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1480,19 +1480,19 @@ func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection {
for i, aggFunc := range p.AggFuncs {
if aggFunc.Name == ast.AggFuncAvg {
// inset a count(column)
avgCount := *aggFunc
avgCount := aggFunc.Clone()
avgCount.Name = ast.AggFuncCount
newAggFuncs = append(newAggFuncs, &avgCount)
newAggFuncs = append(newAggFuncs, avgCount)
avgCount.RetTp = ft
avgCountCol := &expression.Column{
UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(),
RetType: ft,
}
newSchema.Append(avgCountCol)
// insert a sum(column)
avgSum := *aggFunc
avgSum := aggFunc.Clone()
avgSum.Name = ast.AggFuncSum
newAggFuncs = append(newAggFuncs, &avgSum)
newAggFuncs = append(newAggFuncs, avgSum)
newSchema.Append(p.schema.Columns[i])
avgSumCol := p.schema.Columns[i]
// avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end)
Expand Down
6 changes: 6 additions & 0 deletions planner/core/testdata/integration_serial_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
"explain format = 'brief' select /*+ broadcast_join(a,b) */ * from table_1 a, table_1 b where a.value = b.value"
]
},
{
"name": "TestMPPAvgRewrite",
"cases": [
"explain format = 'brief' select /*+ avg_to_cop() */ id, avg(value+1),avg(value) from table_1 group by id"
]
},
{
"name": "TestReadFromStorageHint",
"cases": [
Expand Down
20 changes: 20 additions & 0 deletions planner/core/testdata/integration_serial_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,26 @@
}
]
},
{
"Name": "TestMPPAvgRewrite",
"Cases": [
{
"SQL": "explain format = 'brief' select /*+ avg_to_cop() */ id, avg(value+1),avg(value) from table_1 group by id",
"Plan": [
"Projection 2.00 root test.table_1.id, Column#4, Column#5",
"└─TableReader 2.00 root data:ExchangeSender",
" └─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: PassThrough",
" └─Projection 2.00 batchCop[tiflash] div(Column#4, cast(case(eq(Column#13, 0), 1, Column#13), decimal(20,0) BINARY))->Column#4, div(Column#5, cast(case(eq(Column#14, 0), 1, Column#14), decimal(20,0) BINARY))->Column#5, test.table_1.id",
" └─HashAgg 2.00 batchCop[tiflash] group by:test.table_1.id, funcs:sum(Column#15)->Column#13, funcs:sum(Column#16)->Column#4, funcs:sum(Column#17)->Column#14, funcs:sum(Column#18)->Column#5, funcs:firstrow(test.table_1.id)->test.table_1.id",
" └─ExchangeReceiver 2.00 batchCop[tiflash] ",
" └─ExchangeSender 2.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.table_1.id",
" └─HashAgg 2.00 batchCop[tiflash] group by:Column#29, funcs:count(Column#25)->Column#15, funcs:sum(Column#26)->Column#16, funcs:count(Column#27)->Column#17, funcs:sum(Column#28)->Column#18",
" └─Projection 2.00 batchCop[tiflash] plus(test.table_1.value, 1)->Column#25, plus(test.table_1.value, 1)->Column#26, test.table_1.value, test.table_1.value, test.table_1.id",
" └─TableFullScan 2.00 batchCop[tiflash] table:table_1 keep order:false"
]
}
]
},
{
"Name": "TestReadFromStorageHint",
"Cases": [
Expand Down

0 comments on commit 638272e

Please sign in to comment.