Skip to content

Commit

Permalink
planner/cascades: add transformation rule PushSelDownAggregation (#13106)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis0407 authored Nov 5, 2019
1 parent 25fa7ec commit 5793040
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 26 deletions.
4 changes: 3 additions & 1 deletion planner/cascades/testdata/transformation_rules_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
"select a, b from (select a, b from t as t1 order by a) as t2 where t2.a > 10",
"select a, b from (select a, b, a+b as a_b from t as t1) as t2 where a_b > 10 and b = 1",
"select a, @i:=@i+1 as ii from (select a, @i:=0 from t as t1) as t2 where @i < 10",
"select a, @i:=@i+1 as ii from (select a, @i:=0 from t as t1) as t2 where @i < 10 and a > 10"
"select a, @i:=@i+1 as ii from (select a, @i:=0 from t as t1) as t2 where @i < 10 and a > 10",
"select a, max(b) from t group by a having a > 1",
"select a, avg(b) from t group by a having a > 1 and max(b) > 10"
]
}
]
30 changes: 30 additions & 0 deletions planner/cascades/testdata/transformation_rules_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@
"Group#4 Schema:[Column#1], UniqueKey:[Column#1]",
" TableScan_10 table:t1, pk col:Column#1, cond:[gt(Column#1, 10)]"
]
},
{
"SQL": "select a, max(b) from t group by a having a > 1",
"Result": [
"Group#0 Schema:[Column#14,Column#15], UniqueKey:[Column#14,Column#14]",
" Projection_3 input:[Group#1], Column#1, Column#13",
"Group#1 Schema:[Column#13,Column#1], UniqueKey:[Column#1,Column#1]",
" Aggregation_2 input:[Group#2], group by:Column#1, funcs:max(Column#2), firstrow(Column#1)",
"Group#2 Schema:[Column#1,Column#2], UniqueKey:[Column#1]",
" TableGather_6 input:[Group#3]",
"Group#3 Schema:[Column#1,Column#2], UniqueKey:[Column#1]",
" TableScan_10 table:t, pk col:Column#1, cond:[gt(Column#1, 1)]"
]
},
{
"SQL": "select a, avg(b) from t group by a having a > 1 and max(b) > 10",
"Result": [
"Group#0 Schema:[Column#19,Column#20], UniqueKey:[Column#19,Column#19]",
" Projection_5 input:[Group#1], Column#15, Column#16",
"Group#1 Schema:[Column#15,Column#16,Column#18], UniqueKey:[Column#15,Column#15]",
" Projection_3 input:[Group#2], Column#1, Column#13, Column#14",
"Group#2 Schema:[Column#13,Column#14,Column#1], UniqueKey:[Column#1,Column#1]",
" Selection_10 input:[Group#3], gt(Column#14, 10)",
"Group#3 Schema:[Column#13,Column#14,Column#1], UniqueKey:[Column#1,Column#1]",
" Aggregation_2 input:[Group#4], group by:Column#1, funcs:avg(Column#2), max(Column#2), firstrow(Column#1)",
"Group#4 Schema:[Column#1,Column#2], UniqueKey:[Column#1]",
" TableGather_7 input:[Group#5]",
"Group#5 Schema:[Column#1,Column#2], UniqueKey:[Column#1]",
" TableScan_12 table:t, pk col:Column#1, cond:[gt(Column#1, 1)]"
]
}
]
}
Expand Down
88 changes: 88 additions & 0 deletions planner/cascades/transformation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (
rulePushSelDownTableGather
rulePushSelDownSort
rulePushSelDownProjection
rulePushSelDownAggregation
ruleEnumeratePaths
)

Expand All @@ -58,6 +59,7 @@ var transformationRuleList = []Transformation{
&PushSelDownTableGather{},
&PushSelDownSort{},
&PushSelDownProjection{},
&PushSelDownAggregation{},
&EnumeratePaths{},
}

Expand All @@ -67,6 +69,7 @@ var defaultTransformationMap = map[memo.Operand][]TransformationID{
rulePushSelDownTableGather,
rulePushSelDownSort,
rulePushSelDownProjection,
rulePushSelDownAggregation,
},
memo.OperandDataSource: {
ruleEnumeratePaths,
Expand Down Expand Up @@ -323,3 +326,88 @@ func (r *PushSelDownProjection) OnTransform(old *memo.ExprIter) (newExprs []*mem
newTopSelExpr.SetChildren(newProjGroup)
return []*memo.GroupExpr{newTopSelExpr}, true, false, nil
}

// PushSelDownAggregation is the transformation rule that moves a Selection
// below the Aggregation it sits on. The rule carries no state.
type PushSelDownAggregation struct{}

// GetPattern implements Transformation interface.
// The rule matches a Selection whose child is an Aggregation
// (`Selection -> Aggregation`), with both operands accepted on any engine.
func (r *PushSelDownAggregation) GetPattern() *memo.Pattern {
	aggPattern := memo.NewPattern(memo.OperandAggregation, memo.EngineAll)
	return memo.BuildPattern(memo.OperandSelection, memo.EngineAll, aggPattern)
}

// Match implements Transformation interface.
// It unconditionally reports a match: expr is not inspected, so no check
// beyond the pattern returned by GetPattern is performed here.
func (r *PushSelDownAggregation) Match(expr *memo.ExprIter) bool {
return true
}

// OnTransform implements Transformation interface.
// It rewrites `sel->agg->x` into `agg->sel->x` when every condition can be
// pushed, into `sel->agg->sel->x` when only some of them can, and reports no
// new expressions (leaving the selection as it is) when none of them can.
func (r *PushSelDownAggregation) OnTransform(old *memo.ExprIter) (newExprs []*memo.GroupExpr, eraseOld bool, eraseAll bool, err error) {
	sel := old.GetExpr().ExprNode.(*plannercore.LogicalSelection)
	agg := old.Children[0].GetExpr().ExprNode.(*plannercore.LogicalAggregation)

	// First argument of each aggregate function, in the order of agg.AggFuncs;
	// this is the replacement list handed to ColumnSubstitute below.
	aggArgs := make([]expression.Expression, len(agg.AggFuncs))
	for i := range agg.AggFuncs {
		aggArgs[i] = agg.AggFuncs[i].Args[0]
	}

	groupBySchema := expression.NewSchema(agg.GetGroupByCols()...)
	// onlyGroupByCols reports whether every column extracted from e is one of
	// the aggregation's group-by columns.
	onlyGroupByCols := func(e expression.Expression) bool {
		for _, col := range expression.ExtractColumns(e) {
			if !groupBySchema.Contains(col) {
				return false
			}
		}
		return true
	}

	var toPush, toKeep []expression.Expression
	for _, cond := range sel.Conditions {
		if _, isConst := cond.(*expression.Constant); isConst {
			// A constant predicate is both pushed down and kept above the
			// aggregation. Take "select sum(b) from t group by a having 1=0":
			// if "1=0" were only pushed down, the query would wrongly return
			// a result holding the value 0 instead of an empty result.
			toPush = append(toPush, cond)
			toKeep = append(toKeep, cond)
			continue
		}
		// Anything that is not a scalar function, or that references a column
		// outside the group-by columns, has to stay above the aggregation.
		if _, isFunc := cond.(*expression.ScalarFunction); !isFunc || !onlyGroupByCols(cond) {
			toKeep = append(toKeep, cond)
			continue
		}
		// TODO: Don't substitute since they should be the same column.
		toPush = append(toPush, expression.ColumnSubstitute(cond, agg.Schema(), aggArgs))
	}

	// Nothing can go below the aggregation: report no transformation.
	if len(toPush) == 0 {
		return nil, false, false, nil
	}

	// Build the new selection directly on top of the aggregation's old child group.
	sctx := sel.SCtx()
	inputGroup := old.Children[0].GetExpr().Children[0]
	lowerSel := plannercore.LogicalSelection{Conditions: toPush}.Init(sctx, sel.SelectBlockOffset())
	lowerSelExpr := memo.NewGroupExpr(lowerSel)
	lowerSelExpr.SetChildren(inputGroup)
	lowerGroup := memo.NewGroupWithSchema(lowerSelExpr, inputGroup.Prop.Schema)

	// Re-parent the aggregation onto the pushed-down selection.
	newAggExpr := memo.NewGroupExpr(agg)
	newAggExpr.SetChildren(lowerGroup)

	// Every condition was pushed: the aggregation replaces the old selection.
	if len(toKeep) == 0 {
		return []*memo.GroupExpr{newAggExpr}, true, false, nil
	}

	// Some conditions remain: keep a selection above the new aggregation group.
	aggGroup := memo.NewGroupWithSchema(newAggExpr, agg.Schema())
	upperSel := plannercore.LogicalSelection{Conditions: toKeep}.Init(sctx, sel.SelectBlockOffset())
	upperSelExpr := memo.NewGroupExpr(upperSel)
	upperSelExpr.SetChildren(aggGroup)
	return []*memo.GroupExpr{upperSelExpr}, true, false, nil
}
1 change: 1 addition & 0 deletions planner/cascades/transformation_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (s *testTransformationRuleSuite) TestPredicatePushDown(c *C) {
rulePushSelDownTableGather,
rulePushSelDownSort,
rulePushSelDownProjection,
rulePushSelDownAggregation,
},
memo.OperandDataSource: {
ruleEnumeratePaths,
Expand Down
10 changes: 10 additions & 0 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@ type LogicalAggregation struct {
inputCount float64 // inputCount is the input count of this plan.
}

// GetGroupByCols returns the group-by columns of the aggregation. When they
// have not been collected yet (groupByCols is nil), they are collected first.
// If GroupByItems has been changed, the caller should explicitly re-collect
// the group-by columns before calling this method.
func (la *LogicalAggregation) GetGroupByCols() []*expression.Column {
	if la.groupByCols != nil {
		return la.groupByCols
	}
	la.collectGroupByColumns()
	return la.groupByCols
}

func (la *LogicalAggregation) extractCorrelatedCols() []*expression.CorrelatedColumn {
corCols := la.baseLogicalPlan.extractCorrelatedCols()
for _, expr := range la.GroupByItems {
Expand Down
25 changes: 0 additions & 25 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,31 +343,6 @@ type PhysicalUnionAll struct {
physicalSchemaProducer
}

// AggregationType identifies the mode of an aggregation plan.
type AggregationType int

const (
	// StreamedAgg is the mode that supposes the input is sorted by the group by key.
	StreamedAgg AggregationType = iota
	// FinalAgg is the mode that supposes the input consists of partial results.
	FinalAgg
	// CompleteAgg is the mode that supposes the input consists of original results.
	CompleteAgg
)

// String implements fmt.Stringer interface.
// It returns "stream", "final" or "complete" for the three known modes and
// "unsupported aggregation type" for any other value.
func (at AggregationType) String() string {
	name := "unsupported aggregation type"
	switch at {
	case StreamedAgg:
		name = "stream"
	case FinalAgg:
		name = "final"
	case CompleteAgg:
		name = "complete"
	}
	return name
}

type basePhysicalAgg struct {
physicalSchemaProducer

Expand Down

0 comments on commit 5793040

Please sign in to comment.