diff --git a/cmd/explaintest/portgenerator b/cmd/explaintest/portgenerator deleted file mode 100755 index 74840b24142ad..0000000000000 Binary files a/cmd/explaintest/portgenerator and /dev/null differ diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index f2970e9916ab2..41ae2ceaa6e1f 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -652,9 +652,9 @@ drop table if exists t; create table t(a int, b int, c int); explain select * from (select * from t order by (select 2)) t order by a, b; id estRows task access object operator info -Sort_12 10000.00 root test.t.a:asc, test.t.b:asc -└─TableReader_18 10000.00 root data:TableFullScan_17 - └─TableFullScan_17 10000.00 cop[tikv] table:t keep order:false, stats:pseudo +Sort_13 10000.00 root test.t.a:asc, test.t.b:asc +└─TableReader_19 10000.00 root data:TableFullScan_18 + └─TableFullScan_18 10000.00 cop[tikv] table:t keep order:false, stats:pseudo explain select * from (select * from t order by c) t order by a, b; id estRows task access object operator info Sort_6 10000.00 root test.t.a:asc, test.t.b:asc @@ -784,3 +784,72 @@ Update_4 N/A root N/A ├─IndexRangeScan_9(Build) 0.10 cop[tikv] table:t, index:a(a, b) range:["[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb","[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb"], keep order:false, stats:pseudo └─TableRowIDScan_10(Probe) 0.10 cop[tikv] table:t keep order:false, stats:pseudo drop table if exists t; +create table t(a int, b int); +explain select (select count(n.a) from t) from t n; +id estRows task access object operator info +Projection_9 1.00 root Column#8 +└─Apply_11 1.00 root CARTESIAN left outer join + ├─StreamAgg_23(Build) 1.00 root funcs:count(Column#13)->Column#7 + │ └─TableReader_24 1.00 root data:StreamAgg_15 + │ └─StreamAgg_15 1.00 cop[tikv] funcs:count(test.t.a)->Column#13 + │ └─TableFullScan_22 10000.00 cop[tikv] table:n keep order:false, stats:pseudo + 
└─MaxOneRow_27(Probe) 1.00 root + └─Projection_28 2.00 root Column#7 + └─TableReader_30 2.00 root data:TableFullScan_29 + └─TableFullScan_29 2.00 cop[tikv] table:t keep order:false, stats:pseudo +explain select (select sum((select count(a)))) from t; +id estRows task access object operator info +Projection_23 1.00 root Column#7 +└─Apply_25 1.00 root CARTESIAN left outer join + ├─StreamAgg_37(Build) 1.00 root funcs:count(Column#15)->Column#5 + │ └─TableReader_38 1.00 root data:StreamAgg_29 + │ └─StreamAgg_29 1.00 cop[tikv] funcs:count(test.t.a)->Column#15 + │ └─TableFullScan_36 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + └─HashAgg_43(Probe) 1.00 root funcs:sum(Column#12)->Column#7 + └─HashJoin_44 1.00 root CARTESIAN left outer join + ├─HashAgg_49(Build) 1.00 root group by:1, funcs:sum(Column#16)->Column#12 + │ └─Projection_54 1.00 root cast(Column#6, decimal(65,0) BINARY)->Column#16 + │ └─MaxOneRow_50 1.00 root + │ └─Projection_51 1.00 root Column#5 + │ └─TableDual_52 1.00 root rows:1 + └─TableDual_46(Probe) 1.00 root rows:1 +explain select count(a) from t group by b order by (select count(a)); +id estRows task access object operator info +Projection_11 8000.00 root Column#4 +└─Sort_12 8000.00 root Column#5:asc + └─Apply_15 8000.00 root CARTESIAN left outer join + ├─HashAgg_20(Build) 8000.00 root group by:test.t.b, funcs:count(Column#9)->Column#4 + │ └─TableReader_21 8000.00 root data:HashAgg_16 + │ └─HashAgg_16 8000.00 cop[tikv] group by:test.t.b, funcs:count(test.t.a)->Column#9 + │ └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + └─MaxOneRow_24(Probe) 1.00 root + └─Projection_25 1.00 root Column#4 + └─TableDual_26 1.00 root rows:1 +explain select (select sum(count(a))) from t; +id estRows task access object operator info +Projection_11 1.00 root Column#5 +└─Apply_13 1.00 root CARTESIAN left outer join + ├─StreamAgg_25(Build) 1.00 root funcs:count(Column#8)->Column#4 + │ └─TableReader_26 1.00 root data:StreamAgg_17 + │ 
└─StreamAgg_17 1.00 cop[tikv] funcs:count(test.t.a)->Column#8 + │ └─TableFullScan_24 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + └─StreamAgg_32(Probe) 1.00 root funcs:sum(Column#9)->Column#5 + └─Projection_39 1.00 root cast(Column#4, decimal(65,0) BINARY)->Column#9 + └─TableDual_37 1.00 root rows:1 +explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a)); +id estRows task access object operator info +Projection_16 8000.00 root Column#4, Column#6, Column#5 +└─Sort_17 8000.00 root Column#7:asc + └─Apply_20 8000.00 root CARTESIAN left outer join + ├─Apply_22(Build) 8000.00 root CARTESIAN left outer join + │ ├─HashAgg_27(Build) 8000.00 root group by:test.t.b, funcs:sum(Column#16)->Column#4, funcs:count(Column#17)->Column#5 + │ │ └─TableReader_28 8000.00 root data:HashAgg_23 + │ │ └─HashAgg_23 8000.00 cop[tikv] group by:test.t.b, funcs:sum(test.t.a)->Column#16, funcs:count(test.t.a)->Column#17 + │ │ └─TableFullScan_26 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + │ └─MaxOneRow_31(Probe) 1.00 root + │ └─Projection_32 1.00 root Column#4 + │ └─TableDual_33 1.00 root rows:1 + └─MaxOneRow_34(Probe) 1.00 root + └─Projection_35 1.00 root Column#5 + └─TableDual_36 1.00 root rows:1 +drop table if exists t; diff --git a/cmd/explaintest/r/tpch.result b/cmd/explaintest/r/tpch.result index 21fd1330134e3..ca58a15f76065 100644 --- a/cmd/explaintest/r/tpch.result +++ b/cmd/explaintest/r/tpch.result @@ -711,10 +711,10 @@ and n_name = 'MOZAMBIQUE' order by value desc; id estRows task access object operator info -Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#18 -└─Sort_53 1304801.67 root Column#18:desc - └─Selection_55 1304801.67 root gt(Column#18, NULL) - └─HashAgg_58 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#18, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey +Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#35 +└─Sort_53 1304801.67 root Column#35:desc + 
└─Selection_55 1304801.67 root gt(Column#35, NULL) + └─HashAgg_58 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#35, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey └─Projection_79 1631002.09 root mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty, decimal(20,0) BINARY))->Column#42, tpch.partsupp.ps_partkey, tpch.partsupp.ps_partkey └─HashJoin_62 1631002.09 root inner join, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] ├─HashJoin_70(Build) 20000.00 root inner join, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index 930f097d9d2b3..a01179dd3ddbe 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -210,3 +210,11 @@ create table t(a binary(16) not null, b varchar(2) default null, c varchar(100) explain select * from t where a=x'FA34E1093CB428485734E3917F000000' and b='xb'; explain update t set c = 'ssss' where a=x'FA34E1093CB428485734E3917F000000' and b='xb'; drop table if exists t; + +create table t(a int, b int); +explain select (select count(n.a) from t) from t n; +explain select (select sum((select count(a)))) from t; +explain select count(a) from t group by b order by (select count(a)); +explain select (select sum(count(a))) from t; +explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a)); +drop table if exists t; diff --git a/executor/testdata/agg_suite_out.json b/executor/testdata/agg_suite_out.json index ab65db16fb991..73d920dbbbb3c 100644 --- a/executor/testdata/agg_suite_out.json +++ b/executor/testdata/agg_suite_out.json @@ -49,21 +49,21 @@ "Name": "TestIssue12759HashAggCalledByApply", "Cases": [ [ - "Projection_28 1.00 root Column#3, Column#6, Column#9, Column#12", + "Projection_28 1.00 root Column#9, Column#10, Column#11, Column#12", "└─Apply_30 1.00 root CARTESIAN left outer join", " ├─Apply_32(Build) 1.00 root CARTESIAN 
left outer join", " │ ├─Apply_34(Build) 1.00 root CARTESIAN left outer join", - " │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#3, funcs:firstrow(Column#23)->test.test.a", + " │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#9, funcs:firstrow(Column#23)->test.test.a", " │ │ │ └─TableReader_40 1.00 root data:HashAgg_35", " │ │ │ └─HashAgg_35 1.00 cop[tikv] funcs:sum(test.test.a)->Column#22, funcs:firstrow(test.test.a)->Column#23", " │ │ │ └─TableFullScan_38 10000.00 cop[tikv] table:tt keep order:false, stats:pseudo", - " │ │ └─Projection_43(Probe) 1.00 root ->Column#6", + " │ │ └─Projection_43(Probe) 1.00 root ->Column#10", " │ │ └─Limit_44 1.00 root offset:0, count:1", " │ │ └─TableReader_50 1.00 root data:Limit_49", " │ │ └─Limit_49 1.00 cop[tikv] offset:0, count:1", " │ │ └─Selection_48 1.00 cop[tikv] eq(test.test.a, test.test.a)", " │ │ └─TableFullScan_47 1000.00 cop[tikv] table:test keep order:false, stats:pseudo", - " │ └─Projection_54(Probe) 1.00 root ->Column#9", + " │ └─Projection_54(Probe) 1.00 root ->Column#11", " │ └─Limit_55 1.00 root offset:0, count:1", " │ └─TableReader_61 1.00 root data:Limit_60", " │ └─Limit_60 1.00 cop[tikv] offset:0, count:1", diff --git a/planner/cascades/stringer_test.go b/planner/cascades/stringer_test.go index 85314983548c9..c9a64772d8145 100644 --- a/planner/cascades/stringer_test.go +++ b/planner/cascades/stringer_test.go @@ -86,6 +86,6 @@ func (s *testStringerSuite) TestGroupStringer(c *C) { output[i].SQL = sql output[i].Result = ToString(group) }) - c.Assert(ToString(group), DeepEquals, output[i].Result) + c.Assert(ToString(group), DeepEquals, output[i].Result, Commentf("case:%v, sql:%s", i, sql)) } } diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index c3df6622d4dcb..bc23ef35aaf56 100644 --- a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -1019,10 
+1019,10 @@ { "SQL": "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2", "Plan": [ - "Projection_28 1.00 root Column#3, test.t1.a, test.t1.b", + "Projection_28 1.00 root Column#7, test.t1.a, test.t1.b", "└─Apply_30 1.00 root CARTESIAN left outer join", " ├─Apply_32(Build) 1.00 root CARTESIAN left outer join", - " │ ├─HashAgg_37(Build) 1.00 root funcs:sum(Column#8)->Column#3, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b", + " │ ├─HashAgg_37(Build) 1.00 root funcs:sum(Column#8)->Column#7, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b", " │ │ └─TableReader_38 1.00 root data:HashAgg_39", " │ │ └─HashAgg_39 1.00 cop[tikv] funcs:sum(test.t2.a)->Column#8, funcs:firstrow(test.t2.a)->Column#9, funcs:firstrow(test.t2.b)->Column#10", " │ │ └─TableFullScan_35 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", diff --git a/planner/cascades/testdata/stringer_suite_out.json b/planner/cascades/testdata/stringer_suite_out.json index 61bd297511785..dbcc00aec395f 100644 --- a/planner/cascades/testdata/stringer_suite_out.json +++ b/planner/cascades/testdata/stringer_suite_out.json @@ -290,7 +290,7 @@ "SQL": "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1", "Result": [ "Group#0 Schema:[Column#25]", - " Projection_2 input:[Group#1], eq(test.t.a, test.t.a)->Column#25", + " Projection_3 input:[Group#1], eq(test.t.a, test.t.a)->Column#25", "Group#1 Schema:[test.t.a,test.t.b,test.t.a]", " Apply_9 input:[Group#2,Group#3], left outer join", "Group#2 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index f66e9b614e654..f7bfb4783b32e 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -329,11 +329,24 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { if er.aggrMap != nil { 
index, ok = er.aggrMap[v] } - if !ok { - er.err = ErrInvalidGroupFuncUse + if ok { + // index < 0 indicates this is a correlated aggregate belonging to outer query, + // for which a correlated column will be created later, so we append a null constant + // as a temporary result expression. + if index < 0 { + er.ctxStackAppend(expression.NewNull(), types.EmptyName) + } else { + // index >= 0 indicates this is a regular aggregate column + er.ctxStackAppend(er.schema.Columns[index], er.names[index]) + } return inNode, true } - er.ctxStackAppend(er.schema.Columns[index], er.names[index]) + // replace correlated aggregate in sub-query with its corresponding correlated column + if col, ok := er.b.correlatedAggMapper[v]; ok { + er.ctxStackAppend(col, types.EmptyName) + return inNode, true + } + er.err = ErrInvalidGroupFuncUse return inNode, true case *ast.ColumnNameExpr: if index, ok := er.b.colMapper[v]; ok { diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 563a018336530..8414956ab30c5 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1626,6 +1626,86 @@ func (s *testIntegrationSuite) TestUpdateMultiUpdatePK(c *C) { tk.MustQuery("SELECT * FROM t").Check(testkit.Rows("2 12")) } +func (s *testIntegrationSuite) TestCorrelatedAggregate(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // #18350 + tk.MustExec("DROP TABLE IF EXISTS tab, tab2") + tk.MustExec("CREATE TABLE tab(i INT)") + tk.MustExec("CREATE TABLE tab2(j INT)") + tk.MustExec("insert into tab values(1),(2),(3)") + tk.MustExec("insert into tab2 values(1),(2),(3),(15)") + tk.MustQuery(`SELECT m.i, + (SELECT COUNT(n.j) + FROM tab2 WHERE j=15) AS o + FROM tab m, tab2 n GROUP BY 1 order by m.i`).Check(testkit.Rows("1 4", "2 4", "3 4")) + tk.MustQuery(`SELECT + (SELECT COUNT(n.j) + FROM tab2 WHERE j=15) AS o + FROM tab m, tab2 n order by m.i`).Check(testkit.Rows("12")) + + // #17748 + tk.MustExec("drop table if exists 
t1, t2") + tk.MustExec("create table t1 (a int, b int)") + tk.MustExec("create table t2 (m int, n int)") + tk.MustExec("insert into t1 values (2,2), (2,2), (3,3), (3,3), (3,3), (4,4)") + tk.MustExec("insert into t2 values (1,11), (2,22), (3,32), (4,44), (4,44)") + tk.MustExec("set @@sql_mode='TRADITIONAL'") + + tk.MustQuery(`select count(*) c, a, + ( select group_concat(count(a)) from t2 where m = a ) + from t1 group by a order by a`). + Check(testkit.Rows("2 2 2", "3 3 3", "1 4 1,1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("insert into t values (1,1),(2,1),(2,2),(3,1),(3,2),(3,3)") + + // Sub-queries in SELECT fields + // from SELECT fields + tk.MustQuery("select (select count(a)) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select (select (select count(a)))) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select (select count(n.a)) from t m order by count(m.b)) from t n").Check(testkit.Rows("6")) + // from WHERE + tk.MustQuery("select (select count(n.a) from t where count(n.a)=3) from t n").Check(testkit.Rows("")) + tk.MustQuery("select (select count(a) from t where count(distinct n.a)=3) from t n").Check(testkit.Rows("6")) + // from HAVING + tk.MustQuery("select (select count(n.a) from t having count(n.a)=6 limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=6 limit 1) from t n").Check(testkit.Rows("")) + // from ORDER BY + tk.MustQuery("select (select count(n.a) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(distinct n.b) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("3")) 
+ // from TableRefsClause + tk.MustQuery("select (select cnt from (select count(a) cnt) s) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(cnt) from (select count(a) cnt) s) from t").Check(testkit.Rows("1")) + // from sub-query inside aggregate + tk.MustQuery("select (select sum((select count(a)))) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum((select count(a))+sum(a))) from t").Check(testkit.Rows("20")) + // from GROUP BY + tk.MustQuery("select (select count(a) from t group by count(n.a)) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(distinct a) from t group by count(n.a)) from t n").Check(testkit.Rows("3")) + + // Sub-queries in HAVING + tk.MustQuery("select sum(a) from t having (select count(a)) = 0").Check(testkit.Rows()) + tk.MustQuery("select sum(a) from t having (select count(a)) > 0").Check(testkit.Rows("14")) + + // Sub-queries in ORDER BY + tk.MustQuery("select count(a) from t group by b order by (select count(a))").Check(testkit.Rows("1", "2", "3")) + tk.MustQuery("select count(a) from t group by b order by (select -count(a))").Check(testkit.Rows("3", "2", "1")) + + // Nested aggregate (correlated aggregate inside aggregate) + tk.MustQuery("select (select sum(count(a))) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum(sum(a))) from t").Check(testkit.Rows("14")) + + // Combining aggregates + tk.MustQuery("select count(a), (select count(a)) from t").Check(testkit.Rows("6 6")) + tk.MustQuery("select sum(distinct b), count(a), (select count(a)), (select cnt from (select sum(distinct b) as cnt) n) from t"). 
+ Check(testkit.Rows("6 6 6 6")) +} + func (s *testIntegrationSuite) TestInvalidNamedWindowSpec(c *C) { // #12356 tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index d48a747e59260..ca66fd458c536 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -171,7 +171,8 @@ func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, true } -func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) { +func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression, + correlatedAggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, map[int]int, error) { b.optFlag |= flagBuildKeyInfo b.optFlag |= flagPushDownAgg // We may apply aggregation eliminate optimization. 
@@ -231,23 +232,40 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu newFunc.OrderByItems = append(newFunc.OrderByItems, &util.ByItems{Expr: newByItem, Desc: byItem.Desc}) } } + // combine identical aggregate functions combined := false - for j, oldFunc := range plan4Agg.AggFuncs { + for j := 0; j < i; j++ { + oldFunc := plan4Agg.AggFuncs[aggIndexMap[j]] if oldFunc.Equal(b.ctx, newFunc) { - aggIndexMap[i] = j + aggIndexMap[i] = aggIndexMap[j] combined = true + if _, ok := correlatedAggMap[aggFunc]; ok { + if _, ok = b.correlatedAggMapper[aggFuncList[j]]; !ok { + b.correlatedAggMapper[aggFuncList[j]] = &expression.CorrelatedColumn{ + Column: *schema4Agg.Columns[aggIndexMap[j]], + } + } + b.correlatedAggMapper[aggFunc] = b.correlatedAggMapper[aggFuncList[j]] + } break } } + // create new columns for aggregate functions which show up first if !combined { position := len(plan4Agg.AggFuncs) aggIndexMap[i] = position plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) - schema4Agg.Append(&expression.Column{ + column := expression.Column{ UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), RetType: newFunc.RetTp, - }) + } + schema4Agg.Append(&column) names = append(names, types.EmptyName) + if _, ok := correlatedAggMap[aggFunc]; ok { + b.correlatedAggMapper[aggFunc] = &expression.CorrelatedColumn{ + Column: column, + } + } } } for i, col := range p.Schema().Columns { @@ -293,6 +311,34 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu return plan4Agg, aggIndexMap, nil } +func (b *PlanBuilder) buildTableRefsWithCache(ctx context.Context, from *ast.TableRefsClause) (p LogicalPlan, err error) { + return b.buildTableRefs(ctx, from, true) +} + +func (b *PlanBuilder) buildTableRefs(ctx context.Context, from *ast.TableRefsClause, useCache bool) (p LogicalPlan, err error) { + if from == nil { + p = b.buildTableDual() + return + } + if !useCache { + return b.buildResultSetNode(ctx, from.TableRefs) + } + var ok 
bool + p, ok = b.cachedResultSetNodes[from.TableRefs] + if ok { + m := b.cachedHandleHelperMap[from.TableRefs] + b.handleHelper.pushMap(m) + return + } + p, err = b.buildResultSetNode(ctx, from.TableRefs) + if err != nil { + return nil, err + } + b.cachedResultSetNodes[from.TableRefs] = p + b.cachedHandleHelperMap[from.TableRefs] = b.handleHelper.tailMap() + return +} + func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSetNode) (p LogicalPlan, err error) { switch x := node.(type) { case *ast.Join: @@ -1863,8 +1909,8 @@ func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan return havingAggMapper, extractor.aggMapper, nil } -func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { - extractor := &AggregateFuncExtractor{} +func (b *PlanBuilder) extractAggFuncsInSelectFields(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} for _, f := range fields { n, _ := f.Expr.Accept(extractor) f.Expr = n.(ast.ExprNode) @@ -1878,6 +1924,38 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega return aggList, totalAggMapper } +func (b *PlanBuilder) extractAggFuncsInByItems(byItems []*ast.ByItem) []*ast.AggregateFuncExpr { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} + for _, f := range byItems { + n, _ := f.Expr.Accept(extractor) + f.Expr = n.(ast.ExprNode) + } + return extractor.AggFuncs +} + +// extractCorrelatedAggFuncs extracts correlated aggregates which belong to outer query from aggregate function list. 
+func (b *PlanBuilder) extractCorrelatedAggFuncs(ctx context.Context, p LogicalPlan, aggFuncs []*ast.AggregateFuncExpr) (outer []*ast.AggregateFuncExpr, err error) {
+	// visited records every aggregate that has already been examined with index -1; it is handed to
+	// the rewriter as the aggregate map while the arguments of later aggregates are rewritten.
+	visited := make(map[*ast.AggregateFuncExpr]int)
+	outerCols := make([]*expression.CorrelatedColumn, 0, len(aggFuncs))
+	localCols := make([]*expression.Column, 0, len(aggFuncs))
+	for _, aggFunc := range aggFuncs {
+		// the two scratch slices are emptied and reused for every aggregate
+		outerCols, localCols = outerCols[:0], localCols[:0]
+		for _, arg := range aggFunc.Args {
+			rewritten, _, rewriteErr := b.rewrite(ctx, arg, p, visited, true)
+			if rewriteErr != nil {
+				return nil, rewriteErr
+			}
+			outerCols = append(outerCols, expression.ExtractCorColumns(rewritten)...)
+			localCols = append(localCols, expression.ExtractColumns(rewritten)...)
+		}
+		// an aggregate is reported as "outer" when its arguments yield at least one correlated column
+		// and no ordinary column
+		if len(localCols) == 0 && len(outerCols) > 0 {
+			outer = append(outer, aggFunc)
+		}
+		visited[aggFunc] = -1
+	}
+	return outer, nil
+}
+
 // resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields.
 func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) (
 	map[*ast.AggregateFuncExpr]int, error) {
@@ -1923,15 +2001,232 @@ func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan)
 	return extractor.aggMapper, nil
 }
 
+// correlatedAggregateResolver is an AST visitor over expression trees.
+// While walking, it gathers the correlated aggregates that have to be evaluated by the outer query.
+type correlatedAggregateResolver struct {
+	ctx       context.Context
+	err       error
+	b         *PlanBuilder
+	outerPlan LogicalPlan
+
+	// correlatedAggFuncs accumulates the aggregate functions found to belong to the outer query
+	correlatedAggFuncs []*ast.AggregateFuncExpr
+}
+
+// Enter implements Visitor interface.
+func (r *correlatedAggregateResolver) Enter(n ast.Node) (ast.Node, bool) {
+	switch v := n.(type) {
+	case *ast.SelectStmt:
+		// Push the outer plan's schema and output names onto the builder's outer stacks
+		// for the duration of this SELECT; Leave pops them again.
+		if r.outerPlan != nil {
+			outerSchema := r.outerPlan.Schema()
+			r.b.outerSchemas = append(r.b.outerSchemas, outerSchema)
+			r.b.outerNames = append(r.b.outerNames, r.outerPlan.OutputNames())
+		}
+		// Any failure is recorded in r.err and reported to the caller of Accept by Leave.
+		r.err = r.resolveSelect(v)
+		// The SELECT has been handled by resolveSelect, so its children are skipped.
+		return n, true
+	}
+	return n, false
+}
+
+// resolveSelect finds and collects correlated aggregates within the SELECT stmt.
+// It resolves and builds FROM clause first to get a source plan, from which we can decide
+// whether a column is correlated or not.
+// Then it collects correlated aggregate from SELECT fields (including sub-queries), HAVING,
+// ORDER BY, WHERE & GROUP BY.
+// Finally it restores the original SELECT stmt.
+// It returns the first error encountered; in that case the SELECT fields are not restored.
+func (r *correlatedAggregateResolver) resolveSelect(sel *ast.SelectStmt) (err error) {
+	// collect correlated aggregate from sub-queries inside FROM clause.
+	useCache, err := r.collectFromTableRefs(r.ctx, sel.From)
+	if err != nil {
+		return err
+	}
+	// we cannot use cache if there are correlated aggregates inside FROM clause,
+	// since the plan we are building now is not correct and needs to be rebuilt later.
+	p, err := r.b.buildTableRefs(r.ctx, sel.From, useCache)
+	if err != nil {
+		return err
+	}
+
+	// similar to process in PlanBuilder.buildSelect
+	originalFields := sel.Fields.Fields
+	sel.Fields.Fields, err = r.b.unfoldWildStar(p, sel.Fields.Fields)
+	if err != nil {
+		return err
+	}
+	if r.b.capFlag&canExpandAST != 0 {
+		originalFields = sel.Fields.Fields
+	}
+
+	hasWindowFuncField := r.b.detectSelectWindow(sel)
+	if hasWindowFuncField {
+		_, err = r.b.resolveWindowFunction(sel, p)
+		if err != nil {
+			return err
+		}
+	}
+
+	_, _, err = r.b.resolveHavingAndOrderBy(sel, p)
+	if err != nil {
+		return err
+	}
+
+	// find and collect correlated aggregates recursively in sub-queries
+	_, err = r.b.resolveCorrelatedAggregates(r.ctx, sel, p)
+	if err != nil {
+		return err
+	}
+
+	// collect from SELECT fields, HAVING, ORDER BY and window functions
+	if r.b.detectSelectAgg(sel) {
+		err = r.collectFromSelectFields(p, sel.Fields.Fields)
+		if err != nil {
+			return err
+		}
+	}
+
+	// collect from WHERE
+	err = r.collectFromWhere(p, sel.Where)
+	if err != nil {
+		return err
+	}
+
+	// collect from GROUP BY
+	err = r.collectFromGroupBy(p, sel.GroupBy)
+	if err != nil {
+		return err
+	}
+
+	// restore the sub-query
+	sel.Fields.Fields = originalFields
+	r.b.handleHelper.popMap()
+	return nil
+}
+
+// collectFromTableRefs walks the FROM clause with a fresh resolver (one without an outer plan)
+// and appends every correlated aggregate it finds to r.correlatedAggFuncs.
+// canCache is true when the FROM clause is absent or contains no correlated aggregate;
+// it is false when at least one was found or when the walk failed, in which case err is set.
+func (r *correlatedAggregateResolver) collectFromTableRefs(ctx context.Context, from *ast.TableRefsClause) (canCache bool, err error) {
+	if from == nil {
+		return true, nil
+	}
+	subResolver := &correlatedAggregateResolver{
+		ctx: r.ctx,
+		b:   r.b,
+	}
+	_, ok := from.TableRefs.Accept(subResolver)
+	if !ok {
+		return false, subResolver.err
+	}
+	if len(subResolver.correlatedAggFuncs) == 0 {
+		return true, nil
+	}
+	r.correlatedAggFuncs = append(r.correlatedAggFuncs, subResolver.correlatedAggFuncs...)
+	return false, nil
+}
+
+// collectFromSelectFields extracts the aggregate functions of the given SELECT fields and appends
+// those classified as outer by extractCorrelatedAggFuncs (against plan p) to r.correlatedAggFuncs.
+// It sets the builder's current clause to fieldList.
+func (r *correlatedAggregateResolver) collectFromSelectFields(p LogicalPlan, fields []*ast.SelectField) error {
+	aggList, _ := r.b.extractAggFuncsInSelectFields(fields)
+	r.b.curClause = fieldList
+	outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList)
+	if err != nil {
+		// propagate the failure instead of discarding it, as collectFromWhere does
+		return err
+	}
+	r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...)
+	return nil
+}
+
+// collectFromGroupBy extracts the aggregate functions of the GROUP BY items and appends those
+// classified as outer by extractCorrelatedAggFuncs (against plan p) to r.correlatedAggFuncs.
+// A nil clause is a no-op. It sets the builder's current clause to groupByClause.
+func (r *correlatedAggregateResolver) collectFromGroupBy(p LogicalPlan, groupBy *ast.GroupByClause) error {
+	if groupBy == nil {
+		return nil
+	}
+	aggList := r.b.extractAggFuncsInByItems(groupBy.Items)
+	r.b.curClause = groupByClause
+	outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList)
+	if err != nil {
+		// propagate the failure instead of discarding it, as collectFromWhere does
+		return err
+	}
+	r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...)
+	return nil
+}
+
+// collectFromWhere extracts the aggregate functions of the WHERE expression and appends those
+// classified as outer by extractCorrelatedAggFuncs (against plan p) to r.correlatedAggFuncs.
+// A nil expression is a no-op. It sets the builder's current clause to whereClause.
+func (r *correlatedAggregateResolver) collectFromWhere(p LogicalPlan, where ast.ExprNode) error {
+	if where == nil {
+		return nil
+	}
+	extractor := &AggregateFuncExtractor{skipAggMap: r.b.correlatedAggMapper}
+	_, _ = where.Accept(extractor)
+	r.b.curClause = whereClause
+	outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, extractor.AggFuncs)
+	if err != nil {
+		return err
+	}
+	r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...)
+	return nil
+}
+
+// Leave implements Visitor interface.
+// For a SELECT stmt it pops the outer schema and names pushed by Enter.
+// It returns false once an error has been recorded in r.err, so that callers checking the flag
+// returned by Accept actually observe the failure.
+func (r *correlatedAggregateResolver) Leave(n ast.Node) (ast.Node, bool) {
+	switch n.(type) {
+	case *ast.SelectStmt:
+		if r.outerPlan != nil {
+			r.b.outerSchemas = r.b.outerSchemas[0 : len(r.b.outerSchemas)-1]
+			r.b.outerNames = r.b.outerNames[0 : len(r.b.outerNames)-1]
+		}
+	}
+	return n, r.err == nil
+}
+
+// resolveCorrelatedAggregates finds and collects all correlated aggregates which should be evaluated
+// in the outer query from all the sub-queries inside SELECT fields.
+func (b *PlanBuilder) resolveCorrelatedAggregates(ctx context.Context, sel *ast.SelectStmt, p LogicalPlan) (map[*ast.AggregateFuncExpr]int, error) { + resolver := &correlatedAggregateResolver{ + ctx: ctx, + b: b, + outerPlan: p, + } + correlatedAggList := make([]*ast.AggregateFuncExpr, 0) + for _, field := range sel.Fields.Fields { + _, ok := field.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + if sel.Having != nil { + _, ok := sel.Having.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + if sel.OrderBy != nil { + for _, item := range sel.OrderBy.Items { + _, ok := item.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + } + correlatedAggMap := make(map[*ast.AggregateFuncExpr]int) + for _, aggFunc := range correlatedAggList { + correlatedAggMap[aggFunc] = len(sel.Fields.Fields) + sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ + Auxiliary: true, + Expr: aggFunc, + AsName: model.NewCIStr(fmt.Sprintf("sel_subq_agg_%d", len(sel.Fields.Fields))), + }) + } + return correlatedAggMap, nil +} + // gbyResolver resolves group by items from select fields. type gbyResolver struct { - ctx sessionctx.Context - fields []*ast.SelectField - schema *expression.Schema - names []*types.FieldName - err error - inExpr bool - isParam bool + ctx sessionctx.Context + fields []*ast.SelectField + schema *expression.Schema + names []*types.FieldName + err error + inExpr bool + isParam bool + skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn exprDepth int // exprDepth is the depth of current expression in expression tree. 
} @@ -1959,7 +2254,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { } func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { - extractor := &AggregateFuncExtractor{} + extractor := &AggregateFuncExtractor{skipAggMap: g.skipAggMap} switch v := inNode.(type) { case *ast.ColumnNameExpr: idx, err := expression.FindFieldName(g.names, v.Name) @@ -2434,10 +2729,11 @@ func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p LogicalPlan, gby *a b.curClause = groupByClause exprs := make([]expression.Expression, 0, len(gby.Items)) resolver := &gbyResolver{ - ctx: b.ctx, - fields: fields, - schema: p.Schema(), - names: p.OutputNames(), + ctx: b.ctx, + fields: fields, + schema: p.Schema(), + names: p.OutputNames(), + skipAggMap: b.correlatedAggMapper, } for _, item := range gby.Items { resolver.inExpr = false @@ -2733,6 +3029,7 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L aggFuncs []*ast.AggregateFuncExpr havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int windowAggMap map[*ast.AggregateFuncExpr]int + correlatedAggMap map[*ast.AggregateFuncExpr]int gbyCols []expression.Expression ) @@ -2741,13 +3038,12 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L b.isForUpdateRead = true } - if sel.From != nil { - p, err = b.buildResultSetNode(ctx, sel.From.TableRefs) - if err != nil { - return nil, err - } - } else { - p = b.buildTableDual() + // For sub-queries, the FROM clause may have already been built in outer query when resolving correlated aggregates. + // If the ResultSetNode inside FROM clause has nothing to do with correlated aggregates, we can simply get the + // existing ResultSetNode from the cache. 
+ p, err = b.buildTableRefsWithCache(ctx, sel.From) + if err != nil { + return nil, err } originalFields := sel.Fields.Fields @@ -2788,6 +3084,15 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L return nil, err } + // We have to resolve correlated aggregate inside sub-queries before building aggregation and building projection, + // for instance, count(a) inside the sub-query of "select (select count(a)) from t" should be evaluated within + // the context of the outer query. So we have to extract such aggregates from sub-queries and put them into + // SELECT field list. + correlatedAggMap, err = b.resolveCorrelatedAggregates(ctx, sel, p) + if err != nil { + return nil, err + } + // b.allNames will be used in evalDefaultExpr(). Default function is special because it needs to find the // corresponding column name, but does not need the value in the column. // For example, select a from t order by default(b), the column b will not be in select fields. Also because @@ -2816,14 +3121,22 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L hasAgg := b.detectSelectAgg(sel) if hasAgg { - aggFuncs, totalMap = b.extractAggFuncs(sel.Fields.Fields) + aggFuncs, totalMap = b.extractAggFuncsInSelectFields(sel.Fields.Fields) + // len(aggFuncs) == 0 and sel.GroupBy == nil indicates that all the aggregate functions inside the SELECT fields + // are actually correlated aggregates from the outer query, which have already been built in the outer query. + // The only thing we need to do is to find them from b.correlatedAggMap in buildProjection. 
+ if len(aggFuncs) == 0 && sel.GroupBy == nil { + hasAgg = false + } + } + if hasAgg { var aggIndexMap map[int]int - p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols) + p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols, correlatedAggMap) if err != nil { return nil, err } - for k, v := range totalMap { - totalMap[k] = aggIndexMap[v] + for agg, idx := range totalMap { + totalMap[agg] = aggIndexMap[idx] } } diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index e67a5a2055f13..cd4c0e926efa8 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -538,7 +538,7 @@ func (s *testPlanSuite) TestColumnPruning(c *C) { ctx := context.Background() for i, tt := range input { - comment := Commentf("for %s", tt) + comment := Commentf("case:%v sql:\"%s\"", i, tt) stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) @@ -612,26 +612,13 @@ func (s *testPlanSuite) TestAllocID(c *C) { } func (s *testPlanSuite) checkDataSourceCols(p LogicalPlan, c *C, ans map[int][]string, comment CommentInterface) { - switch p.(type) { - case *DataSource: - s.testData.OnRecord(func() { - ans[p.ID()] = make([]string, p.Schema().Len()) - }) - colList, ok := ans[p.ID()] - c.Assert(ok, IsTrue, Commentf("For %v DataSource ID %d Not found", comment, p.ID())) - c.Assert(len(p.Schema().Columns), Equals, len(colList), comment) - for i, col := range p.Schema().Columns { - s.testData.OnRecord(func() { - colList[i] = col.String() - }) - c.Assert(col.String(), Equals, colList[i], comment) - } - case *LogicalUnionAll: + switch v := p.(type) { + case *DataSource, *LogicalUnionAll: s.testData.OnRecord(func() { ans[p.ID()] = make([]string, p.Schema().Len()) }) colList, ok := ans[p.ID()] - c.Assert(ok, IsTrue, Commentf("For %v UnionAll ID %d Not found", comment, p.ID())) + c.Assert(ok, IsTrue, Commentf("For %s %T ID %d Not found", comment.CheckCommentString(), v, p.ID())) 
c.Assert(len(p.Schema().Columns), Equals, len(colList), comment) for i, col := range p.Schema().Columns { s.testData.OnRecord(func() { @@ -1670,6 +1657,53 @@ func (s *testPlanSuite) TestSimplyOuterJoinWithOnlyOuterExpr(c *C) { c.Assert(join.JoinType, Equals, RightOuterJoin) } +func (s *testPlanSuite) TestResolvingCorrelatedAggregate(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + best string + }{ + { + sql: "select (select count(a)) from t", + best: "Apply{DataScan(t)->Aggr(count(test.t.a))->Dual->Projection->MaxOneRow}->Projection", + }, + { + sql: "select (select count(n.a) from t) from t n", + best: "Apply{DataScan(n)->Aggr(count(test.t.a))->DataScan(t)->Projection->MaxOneRow}->Projection", + }, + { + sql: "select (select sum(count(a))) from t", + best: "Apply{DataScan(t)->Aggr(count(test.t.a))->Dual->Aggr(sum(Column#13))->MaxOneRow}->Projection", + }, + { + sql: "select (select sum(count(n.a)) from t) from t n", + best: "Apply{DataScan(n)->Aggr(count(test.t.a))->DataScan(t)->Aggr(sum(Column#25))->MaxOneRow}->Projection", + }, + { + sql: "select (select cnt from (select count(a) as cnt) n) from t", + best: "Apply{DataScan(t)->Aggr(count(test.t.a))->Dual->Projection->MaxOneRow}->Projection", + }, + { + sql: "select sum(a), sum(a), count(a), (select count(a)) from t", + best: "Apply{DataScan(t)->Aggr(sum(test.t.a),count(test.t.a))->Dual->Projection->MaxOneRow}->Projection", + }, + } + + ctx := context.TODO() + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + err = Preprocess(s.ctx, stmt, s.is) + c.Assert(err, IsNil, comment) + p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + c.Assert(err, IsNil, comment) + p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagEliminateProjection|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) + c.Assert(err, IsNil, comment) + c.Assert(ToString(p), Equals, tt.best, comment) 
+ } +} + func (s *testPlanSuite) TestFastPathInvalidBatchPointGet(c *C) { // #22040 defer testleak.AfterTest(c)() diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 0f54712613702..8761aadbdb4bf 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -458,6 +458,13 @@ type PlanBuilder struct { // It stores the OutputNames before buildProjection. allNames [][]*types.FieldName + // correlatedAggMapper stores columns for correlated aggregates which should be evaluated in outer query. + correlatedAggMapper map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn + + // cache ResultSetNodes and HandleHelperMap to avoid rebuilding. + cachedResultSetNodes map[*ast.Join]LogicalPlan + cachedHandleHelperMap map[*ast.Join]map[int64][]*expression.Column + // isForUpdateRead should be true in either of the following situations // 1. use `inside insert`, `update`, `delete` or `select for update` statement // 2. isolation level is RC @@ -562,12 +569,15 @@ func NewPlanBuilder(sctx sessionctx.Context, is infoschema.InfoSchema, processor sctx.GetSessionVars().PlannerSelectBlockAsName = make([]ast.HintTable, processor.MaxSelectStmtOffset()+1) } return &PlanBuilder{ - ctx: sctx, - is: is, - colMapper: make(map[*ast.ColumnNameExpr]int), - handleHelper: &handleColHelper{id2HandleMapStack: make([]map[int64][]*expression.Column, 0)}, - hintProcessor: processor, - isForUpdateRead: sctx.GetSessionVars().IsPessimisticReadConsistency(), + ctx: sctx, + is: is, + colMapper: make(map[*ast.ColumnNameExpr]int), + handleHelper: &handleColHelper{id2HandleMapStack: make([]map[int64][]*expression.Column, 0)}, + hintProcessor: processor, + correlatedAggMapper: make(map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn), + cachedResultSetNodes: make(map[*ast.Join]LogicalPlan), + cachedHandleHelperMap: make(map[*ast.Join]map[int64][]*expression.Column), + isForUpdateRead: sctx.GetSessionVars().IsPessimisticReadConsistency(), }, savedBlockNames } diff --git 
a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index 21faa82b7adc8..d3a36a913ad0c 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -2084,12 +2084,12 @@ { "SQL": "select 1 from (select /*+ HASH_JOIN(t1) */ t1.a in (select t2.a from t2) from t1) x;", "Plan": "LeftHashJoin{IndexReader(Index(t1.idx_a)[[NULL,+inf]])->IndexReader(Index(t2.idx_a)[[NULL,+inf]])}->Projection", - "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_3` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" + "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_2` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" }, { "SQL": "select 1 from (select /*+ HASH_JOIN(t1) */ t1.a not in (select t2.a from t2) from t1) x;", "Plan": "LeftHashJoin{IndexReader(Index(t1.idx_a)[[NULL,+inf]])->IndexReader(Index(t2.idx_a)[[NULL,+inf]])}->Projection", - "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_3` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" + "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_2` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" }, { "SQL": "select /*+ INL_JOIN(t1) */ t1.b, t2.b from t1 inner join t2 on t1.a = t2.a;", diff --git a/planner/core/testdata/plan_suite_unexported_out.json b/planner/core/testdata/plan_suite_unexported_out.json index 90047adc595a5..99c2c6a0237c6 100644 --- a/planner/core/testdata/plan_suite_unexported_out.json +++ b/planner/core/testdata/plan_suite_unexported_out.json @@ -101,7 +101,7 @@ "Cases": [ "Join{DataScan(t)->DataScan(s)}(test.t.a,test.t.a)->Projection", "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Projection->Projection", - "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Aggr(firstrow(Column#13),count(test.t.b))->Projection->Projection", + 
"Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Aggr(firstrow(Column#25),count(test.t.b))->Projection->Projection", "Apply{DataScan(t)->DataScan(s)->Sel([eq(test.t.a, test.t.a)])->Aggr(count(test.t.b))}->Projection", "Join{DataScan(t)->DataScan(s)->Aggr(count(test.t.b),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", "Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(s)->Aggr(count(test.t.b),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", @@ -632,7 +632,7 @@ "1": [ "test.t.a" ], - "3": [ + "2": [ "test.t.a", "test.t.b" ] @@ -647,7 +647,7 @@ "test.t.a", "test.t.b" ], - "3": [ + "2": [ "test.t.b" ] }, @@ -655,7 +655,7 @@ "1": [ "test.t.a" ], - "3": [ + "2": [ "test.t.b" ] }, diff --git a/planner/core/util.go b/planner/core/util.go index 84b77e779d2b8..9b7c56a03eb11 100644 --- a/planner/core/util.go +++ b/planner/core/util.go @@ -27,9 +27,11 @@ import ( ) // AggregateFuncExtractor visits Expr tree. -// It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. +// It collects AggregateFuncExpr from AST Node. type AggregateFuncExtractor struct { - inAggregateFuncExpr bool + // skipAggMap stores correlated aggregate functions which have been built in outer query, + // so extractor in sub-query will skip these aggregate functions. + skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn // AggFuncs is the collected AggregateFuncExprs. AggFuncs []*ast.AggregateFuncExpr } @@ -37,8 +39,6 @@ type AggregateFuncExtractor struct { // Enter implements Visitor interface. 
func (a *AggregateFuncExtractor) Enter(n ast.Node) (ast.Node, bool) { switch n.(type) { - case *ast.AggregateFuncExpr: - a.inAggregateFuncExpr = true case *ast.SelectStmt, *ast.UnionStmt: return n, true } @@ -49,8 +49,9 @@ func (a *AggregateFuncExtractor) Enter(n ast.Node) (ast.Node, bool) { func (a *AggregateFuncExtractor) Leave(n ast.Node) (ast.Node, bool) { switch v := n.(type) { case *ast.AggregateFuncExpr: - a.inAggregateFuncExpr = false - a.AggFuncs = append(a.AggFuncs, v) + if _, ok := a.skipAggMap[v]; !ok { + a.AggFuncs = append(a.AggFuncs, v) + } } return n, true }