Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix aggregation hint didn't work in some cases #11996

Merged
merged 3 commits into from
Sep 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast.
// it will be rewrote to t.id < (select max(s.id) from s).
func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) {
plan4Agg := LogicalAggregation{}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)

// Create a "max" or "min" aggregation.
Expand Down Expand Up @@ -567,6 +570,9 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
Expand Down Expand Up @@ -601,6 +607,9 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
Expand Down
9 changes: 6 additions & 3 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
b.optFlag = b.optFlag | flagEliminateProjection

plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx)
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...)
// aggIdxMap maps the old index to new index after applying common aggregation functions elimination.
aggIndexMap := make(map[int]int)
Expand Down Expand Up @@ -149,9 +152,6 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
plan4Agg.GroupByItems = gbyItems
plan4Agg.SetSchema(schema4Agg)
plan4Agg.collectGroupByColumns()
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
return plan4Agg, aggIndexMap, nil
}

Expand Down Expand Up @@ -794,6 +794,9 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggr
AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()),
GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]),
}.Init(b.ctx)
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.collectGroupByColumns()
for _, col := range child.Schema().Columns {
aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false)
Expand Down
63 changes: 43 additions & 20 deletions planner/core/physical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1603,56 +1603,79 @@ func (s *testPlanSuite) TestAggregationHints(c *C) {
sessionVars.HashAggPartialConcurrency = 1

tests := []struct {
sql string
best string
warning string
sql string
best string
warning string
aggPushDown bool
}{
// without Aggregation hints
{
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
warning: "",
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
},
{
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
warning: "",
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
},
// with Aggregation hints
{
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
warning: "",
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
},
{
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
warning: "",
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
},
// test conflict warning
{
sql: "select /*+ HASH_AGG() STREAM_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
warning: "[planner:1815]Optimizer aggregation hints are conflicted",
},
// additional test
{
sql: "select /*+ STREAM_AGG() */ distinct a from t",
best: "TableReader(Table(t)->StreamAgg)->StreamAgg",
},
{
sql: "select /*+ HASH_AGG() */ t1.a from t t1 where t1.a < any(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t)->HashAgg)->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
},
{
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a != any(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t))->Projection->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
},
{
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a = all(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection->HashAgg}->Projection->Projection",
},
{
sql: "select /*+ STREAM_AGG() */ sum(t1.a) from t t1 join t t2 on t1.b = t2.b group by t1.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Sort->Projection->StreamAgg}(test.t2.b,test.t1.b)->HashAgg",
warning: "[planner:1815]Optimizer Hint STREAM_AGG is inapplicable",
aggPushDown: true,
},
}
ctx := context.Background()
for i, test := range tests {
comment := Commentf("case:%v sql:%s", i, test)
se.GetSessionVars().StmtCtx.SetWarnings(nil)
se.GetSessionVars().AllowAggPushDown = test.aggPushDown

stmt, err := s.ParseOneStmt(test.sql, "", "")
c.Assert(err, IsNil, comment)

p, err := planner.Optimize(ctx, se, stmt, s.is)
c.Assert(err, IsNil)
c.Assert(core.ToString(p), Equals, test.best)
c.Assert(core.ToString(p), Equals, test.best, comment)

warnings := se.GetSessionVars().StmtCtx.GetWarnings()
if test.warning == "" {
c.Assert(len(warnings), Equals, 0)
c.Assert(len(warnings), Equals, 0, comment)
} else {
c.Assert(len(warnings), Equals, 1)
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning)
c.Assert(warnings[0].Err.Error(), Equals, test.warning)
c.Assert(len(warnings), Equals, 1, comment)
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment)
c.Assert(warnings[0].Err.Error(), Equals, test.warning, comment)
}
}
}
Expand Down
22 changes: 12 additions & 10 deletions planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a
// tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't
// process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator.
// If the pushed aggregation is grouped by unique key, it's no need to push it down.
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) {
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int, preferAggType uint) (_ LogicalPlan, err error) {
child := join.children[childIdx]
if aggregation.IsAllFirstRow(aggFuncs) {
return child, nil
Expand All @@ -204,7 +204,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg
return child, nil
}
}
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols)
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, preferAggType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -247,10 +247,11 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation.
return false
}

func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) {
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, preferAggType uint) (*LogicalAggregation, error) {
agg := LogicalAggregation{
GroupByItems: expression.Column2Exprs(gbyCols),
groupByCols: gbyCols,
GroupByItems: expression.Column2Exprs(gbyCols),
groupByCols: gbyCols,
preferAggType: preferAggType,
}.Init(ctx)
aggLen := len(aggFuncs) + len(gbyCols)
newAggFuncDescs := make([]*aggregation.AggFuncDesc, 0, aggLen)
Expand Down Expand Up @@ -282,8 +283,9 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs
func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan {
ctx := agg.ctx
newAgg := LogicalAggregation{
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
preferAggType: agg.preferAggType,
}.Init(ctx)
newAgg.SetSchema(agg.schema.Clone())
for _, aggFunc := range agg.AggFuncs {
Expand Down Expand Up @@ -340,15 +342,15 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
if rightInvalid {
rChild = join.children[1]
} else {
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1)
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1, agg.preferAggType)
if err != nil {
return nil, err
}
}
if leftInvalid {
lChild = join.children[0]
} else {
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0, agg.preferAggType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -380,7 +382,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
} else if union, ok1 := child.(*LogicalUnionAll); ok1 {
var gbyCols []*expression.Column
gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil)
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols)
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols, agg.preferAggType)
if err != nil {
return nil, err
}
Expand Down