Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix agg elimination logic after agg pushed down through a join (#44941) #45096

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,54 @@ func TestIssue27751(t *testing.T) {
tk.MustQuery("select group_concat(nname order by 1 desc separator '#' ) from t;").Check(testkit.Rows("33#2"))
}

func TestIssue44795(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec(`use test`)
tk.MustExec(`DROP TABLE IF EXISTS c`)

// case from tcph.
tk.MustExec("CREATE TABLE `customer` (" +
" `C_CUSTKEY` bigint(20) NOT NULL," +
" `C_NAME` varchar(25) NOT NULL," +
" `C_ADDRESS` varchar(40) NOT NULL," +
" `C_NATIONKEY` bigint(20) NOT NULL," +
" `C_PHONE` char(15) NOT NULL," +
" `C_ACCTBAL` decimal(15,2) NOT NULL," +
" `C_MKTSEGMENT` char(10) NOT NULL," +
" `C_COMMENT` varchar(117) NOT NULL," +
" PRIMARY KEY (`C_CUSTKEY`) /*T![clustered_index] CLUSTERED */" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")

tk.MustExec("CREATE TABLE `orders` (" +
" `O_ORDERKEY` bigint(20) NOT NULL," +
" `O_CUSTKEY` bigint(20) NOT NULL," +
" `O_ORDERSTATUS` char(1) NOT NULL," +
" `O_TOTALPRICE` decimal(15,2) NOT NULL," +
" `O_ORDERDATE` date NOT NULL," +
" `O_ORDERPRIORITY` char(15) NOT NULL," +
" `O_CLERK` char(15) NOT NULL," +
" `O_SHIPPRIORITY` bigint(20) NOT NULL," +
" `O_COMMENT` varchar(79) NOT NULL," +
" PRIMARY KEY (`O_ORDERKEY`) /*T![clustered_index] CLUSTERED */" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")

tk.MustExec("set tidb_opt_agg_push_down=ON;")

tk.MustQuery("explain format='brief' SELECT /*+ hash_join_build(customer) */ c_custkey, count(o_orderkey) as c_count from customer " +
"left join orders on c_custkey = o_custkey and o_comment not like '%special%requests%' group by c_custkey;").Check(testkit.Rows(
"Projection 8000.00 root test.customer.c_custkey, Column#18",
"└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#19)->Column#18, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey",
" └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]",
" ├─TableReader(Build) 10000.00 root data:TableFullScan",
" │ └─TableFullScan 10000.00 cop[tikv] table:customer keep order:false, stats:pseudo",
" └─HashAgg(Probe) 6400.00 root group by:test.orders.o_custkey, funcs:count(Column#20)->Column#19, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" └─TableReader 6400.00 root data:HashAgg",
" └─HashAgg 6400.00 cop[tikv] group by:test.orders.o_custkey, funcs:count(test.orders.o_orderkey)->Column#20",
" └─Selection 8000.00 cop[tikv] not(like(test.orders.o_comment, \"%special%requests%\", 92))",
" └─TableFullScan 10000.00 cop[tikv] table:orders keep order:false, stats:pseudo"))
}

func TestIssue26885(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
8 changes: 4 additions & 4 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func TestAggPushDownLeftJoin(t *testing.T) {
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0"))
tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from customer left outer join orders " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows(
"Projection 10000.00 root test.customer.c_custkey, Column#7",
"└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey",
"Projection 8000.00 root test.customer.c_custkey, Column#7",
"└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#8)->Column#7, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey",
" └─HashJoin 10000.00 root left outer join, equal:[eq(test.customer.c_custkey, test.orders.o_custkey)]",
" ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" │ └─TableReader 8000.00 root data:HashAgg",
Expand All @@ -152,8 +152,8 @@ func TestAggPushDownLeftJoin(t *testing.T) {
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows("6 0"))
tk.MustQuery("explain format='brief' select c_custkey, count(o_orderkey) as c_count from orders right outer join customer " +
"on c_custkey = o_custkey group by c_custkey").Check(testkit.Rows(
"Projection 10000.00 root test.customer.c_custkey, Column#7",
"└─Projection 10000.00 root if(isnull(Column#8), 0, 1)->Column#7, test.customer.c_custkey",
"Projection 8000.00 root test.customer.c_custkey, Column#7",
"└─HashAgg 8000.00 root group by:test.customer.c_custkey, funcs:count(Column#8)->Column#7, funcs:firstrow(test.customer.c_custkey)->test.customer.c_custkey",
" └─HashJoin 10000.00 root right outer join, equal:[eq(test.orders.o_custkey, test.customer.c_custkey)]",
" ├─HashAgg(Build) 8000.00 root group by:test.orders.o_custkey, funcs:count(Column#9)->Column#8, funcs:firstrow(test.orders.o_custkey)->test.orders.o_custkey",
" │ └─TableReader 8000.00 root data:HashAgg",
Expand Down
43 changes: 42 additions & 1 deletion planner/core/rule_aggregation_elimination.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,23 @@ type aggregationEliminator struct {
}

type aggregationEliminateChecker struct {
// used for agg pushed down cases, for example:
// agg -> join -> datasource1
// -> datasource2
// we just make a new agg upon datasource1 or datasource2, while the old agg is still existed and waiting for elimination.
// Note when the old agg is like below, and join is an outer join type, rewriting old agg in elimination logic has some problem.
// eg:
// count(a) -> ifnull(col#x, 0, 1) in rewriteExpr of agg function, since col#x is already the final pushed-down aggregation's
// result from new join schema, we don't need to take every row as count 1 when they don't have not-null flag in a.tryToEliminateAggregation(oldAgg, opt),
// which is not suitable here.
oldAggEliminationCheck bool
}

// tryToEliminateAggregation will eliminate aggregation grouped by unique key.
// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`.
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
// If we can eliminate agg successful, we return a projection. Else we return a nil pointer.
func (*aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation, opt *logicalOptimizeOp) *LogicalProjection {
func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation, opt *logicalOptimizeOp) *LogicalProjection {
for _, af := range agg.AggFuncs {
// TODO(issue #9968): Actually, we can rewrite GROUP_CONCAT when all the
// arguments it accepts are promised to be NOT-NULL.
Expand All @@ -64,6 +74,9 @@ func (*aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggreg
}
}
if coveredByUniqueKey {
if a.oldAggEliminationCheck && !CheckCanConvertAggToProj(agg) {
return nil
}
// GroupByCols has unique key, so this aggregation can be removed.
if ok, proj := ConvertAggToProj(agg, agg.schema); ok {
proj.SetChildren(agg.children[0])
Expand Down Expand Up @@ -138,6 +151,34 @@ func appendDistinctEliminateTraceStep(agg *LogicalAggregation, uniqueKey express
opt.appendStepToCurrent(agg.ID(), agg.TP(), reason, action)
}

// CheckCanConvertAggToProj check whether a special old aggregation (which has already been pushed down) to projection.
// link: issue#44795
func CheckCanConvertAggToProj(agg *LogicalAggregation) bool {
var mayNullSchema *expression.Schema
if join, ok := agg.Children()[0].(*LogicalJoin); ok {
if join.JoinType == LeftOuterJoin {
mayNullSchema = join.Children()[1].Schema()
}
if join.JoinType == RightOuterJoin {
mayNullSchema = join.Children()[0].Schema()
}
if mayNullSchema == nil {
return true
}
// once agg function args has intersection with mayNullSchema, return nil (means elimination fail)
for _, fun := range agg.AggFuncs {
mayNullCols := expression.ExtractColumnsFromExpressions(nil, fun.Args, func(column *expression.Column) bool {
// collect may-null cols.
return mayNullSchema.Contains(column)
})
if len(mayNullCols) != 0 {
return false
}
}
}
return true
}

// ConvertAggToProj convert aggregation to projection.
func ConvertAggToProj(agg *LogicalAggregation, schema *expression.Schema) (bool, *LogicalProjection) {
proj := LogicalProjection{
Expand Down
31 changes: 30 additions & 1 deletion planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(oldAgg *LogicalAggregation,
}
tmpSchema := expression.NewSchema(gbyCols...)
for _, key := range child.Schema().Keys {
if tmpSchema.ColumnsIndices(key) != nil {
if tmpSchema.ColumnsIndices(key) != nil { // gby item need to be covered by key.
return child, nil
}
}
Expand Down Expand Up @@ -504,10 +504,39 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
resetNotNullFlag(join.schema, 0, lChild.Schema().Len())
}
buildKeyInfo(join)
// count(a) -> ifnull(col#x, 0, 1) in rewriteExpr of agg function, since col#x is already the final
// pushed-down aggregation's result, we don't need to take every row as count 1 when they don't have
// not-null flag in a.tryToEliminateAggregation(oldAgg, opt), which is not suitable here.
oldCheck := a.oldAggEliminationCheck
a.oldAggEliminationCheck = true
proj := a.tryToEliminateAggregation(agg, opt)
if proj != nil {
p = proj
}
a.oldAggEliminationCheck = oldCheck

// Combine the aggregation elimination logic below since new agg's child key info has changed.
// Notice that even if we eliminate new agg below if possible, the agg's schema is inherited by proj.
// Therefore, we don't need to set the join's schema again, just build the keyInfo again.
changed := false
if newAgg, ok1 := lChild.(*LogicalAggregation); ok1 {
proj := a.tryToEliminateAggregation(newAgg, opt)
if proj != nil {
lChild = proj
changed = true
}
}
if newAgg, ok2 := rChild.(*LogicalAggregation); ok2 {
proj := a.tryToEliminateAggregation(newAgg, opt)
if proj != nil {
rChild = proj
changed = true
}
}
if changed {
join.SetChildren(lChild, rChild)
buildKeyInfo(join)
}
}
} else if proj, ok1 := child.(*LogicalProjection); ok1 {
// push aggregation across projection
Expand Down