Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix correlated aggregates which should be evaluated in outer query (#21431) #21877

Merged
merged 13 commits into from
Jan 28, 2021
Merged
Binary file removed cmd/explaintest/portgenerator
Binary file not shown.
75 changes: 72 additions & 3 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -652,9 +652,9 @@ drop table if exists t;
create table t(a int, b int, c int);
explain select * from (select * from t order by (select 2)) t order by a, b;
id estRows task access object operator info
Sort_12 10000.00 root test.t.a:asc, test.t.b:asc
└─TableReader_18 10000.00 root data:TableFullScan_17
└─TableFullScan_17 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
Sort_13 10000.00 root test.t.a:asc, test.t.b:asc
└─TableReader_19 10000.00 root data:TableFullScan_18
└─TableFullScan_18 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select * from (select * from t order by c) t order by a, b;
id estRows task access object operator info
Sort_6 10000.00 root test.t.a:asc, test.t.b:asc
Expand Down Expand Up @@ -784,3 +784,72 @@ Update_4 N/A root N/A
├─IndexRangeScan_9(Build) 0.10 cop[tikv] table:t, index:a(a, b) range:["[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb","[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb"], keep order:false, stats:pseudo
└─TableRowIDScan_10(Probe) 0.10 cop[tikv] table:t keep order:false, stats:pseudo
drop table if exists t;
create table t(a int, b int);
explain select (select count(n.a) from t) from t n;
id estRows task access object operator info
Projection_9 1.00 root Column#8
└─Apply_11 1.00 root CARTESIAN left outer join
├─StreamAgg_23(Build) 1.00 root funcs:count(Column#13)->Column#7
│ └─TableReader_24 1.00 root data:StreamAgg_15
│ └─StreamAgg_15 1.00 cop[tikv] funcs:count(test.t.a)->Column#13
│ └─TableFullScan_22 10000.00 cop[tikv] table:n keep order:false, stats:pseudo
└─MaxOneRow_27(Probe) 1.00 root
└─Projection_28 2.00 root Column#7
└─TableReader_30 2.00 root data:TableFullScan_29
└─TableFullScan_29 2.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select (select sum((select count(a)))) from t;
id estRows task access object operator info
Projection_23 1.00 root Column#7
└─Apply_25 1.00 root CARTESIAN left outer join
├─StreamAgg_37(Build) 1.00 root funcs:count(Column#15)->Column#5
│ └─TableReader_38 1.00 root data:StreamAgg_29
│ └─StreamAgg_29 1.00 cop[tikv] funcs:count(test.t.a)->Column#15
│ └─TableFullScan_36 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─HashAgg_43(Probe) 1.00 root funcs:sum(Column#12)->Column#7
└─HashJoin_44 1.00 root CARTESIAN left outer join
├─HashAgg_49(Build) 1.00 root group by:1, funcs:sum(Column#16)->Column#12
│ └─Projection_54 1.00 root cast(Column#6, decimal(65,0) BINARY)->Column#16
│ └─MaxOneRow_50 1.00 root
│ └─Projection_51 1.00 root Column#5
│ └─TableDual_52 1.00 root rows:1
└─TableDual_46(Probe) 1.00 root rows:1
explain select count(a) from t group by b order by (select count(a));
id estRows task access object operator info
Projection_11 8000.00 root Column#4
└─Sort_12 8000.00 root Column#5:asc
└─Apply_15 8000.00 root CARTESIAN left outer join
├─HashAgg_20(Build) 8000.00 root group by:test.t.b, funcs:count(Column#9)->Column#4
│ └─TableReader_21 8000.00 root data:HashAgg_16
│ └─HashAgg_16 8000.00 cop[tikv] group by:test.t.b, funcs:count(test.t.a)->Column#9
│ └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─MaxOneRow_24(Probe) 1.00 root
└─Projection_25 1.00 root Column#4
└─TableDual_26 1.00 root rows:1
explain select (select sum(count(a))) from t;
id estRows task access object operator info
Projection_11 1.00 root Column#5
└─Apply_13 1.00 root CARTESIAN left outer join
├─StreamAgg_25(Build) 1.00 root funcs:count(Column#8)->Column#4
│ └─TableReader_26 1.00 root data:StreamAgg_17
│ └─StreamAgg_17 1.00 cop[tikv] funcs:count(test.t.a)->Column#8
│ └─TableFullScan_24 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─StreamAgg_32(Probe) 1.00 root funcs:sum(Column#9)->Column#5
└─Projection_39 1.00 root cast(Column#4, decimal(65,0) BINARY)->Column#9
└─TableDual_37 1.00 root rows:1
explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a));
id estRows task access object operator info
Projection_16 8000.00 root Column#4, Column#6, Column#5
└─Sort_17 8000.00 root Column#7:asc
└─Apply_20 8000.00 root CARTESIAN left outer join
├─Apply_22(Build) 8000.00 root CARTESIAN left outer join
│ ├─HashAgg_27(Build) 8000.00 root group by:test.t.b, funcs:sum(Column#16)->Column#4, funcs:count(Column#17)->Column#5
│ │ └─TableReader_28 8000.00 root data:HashAgg_23
│ │ └─HashAgg_23 8000.00 cop[tikv] group by:test.t.b, funcs:sum(test.t.a)->Column#16, funcs:count(test.t.a)->Column#17
│ │ └─TableFullScan_26 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
│ └─MaxOneRow_31(Probe) 1.00 root
│ └─Projection_32 1.00 root Column#4
│ └─TableDual_33 1.00 root rows:1
└─MaxOneRow_34(Probe) 1.00 root
└─Projection_35 1.00 root Column#5
└─TableDual_36 1.00 root rows:1
drop table if exists t;
8 changes: 4 additions & 4 deletions cmd/explaintest/r/tpch.result
Original file line number Diff line number Diff line change
Expand Up @@ -711,10 +711,10 @@ and n_name = 'MOZAMBIQUE'
order by
value desc;
id estRows task access object operator info
Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#18
└─Sort_53 1304801.67 root Column#18:desc
└─Selection_55 1304801.67 root gt(Column#18, NULL)
└─HashAgg_58 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#18, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey
Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#35
└─Sort_53 1304801.67 root Column#35:desc
└─Selection_55 1304801.67 root gt(Column#35, NULL)
└─HashAgg_58 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#35, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey
└─Projection_79 1631002.09 root mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty, decimal(20,0) BINARY))->Column#42, tpch.partsupp.ps_partkey, tpch.partsupp.ps_partkey
└─HashJoin_62 1631002.09 root inner join, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)]
├─HashJoin_70(Build) 20000.00 root inner join, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)]
Expand Down
8 changes: 8 additions & 0 deletions cmd/explaintest/t/explain_easy.test
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,11 @@ create table t(a binary(16) not null, b varchar(2) default null, c varchar(100)
explain select * from t where a=x'FA34E1093CB428485734E3917F000000' and b='xb';
explain update t set c = 'ssss' where a=x'FA34E1093CB428485734E3917F000000' and b='xb';
drop table if exists t;

create table t(a int, b int);
explain select (select count(n.a) from t) from t n;
explain select (select sum((select count(a)))) from t;
explain select count(a) from t group by b order by (select count(a));
explain select (select sum(count(a))) from t;
explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a));
drop table if exists t;
8 changes: 4 additions & 4 deletions executor/testdata/agg_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@
"Name": "TestIssue12759HashAggCalledByApply",
"Cases": [
[
"Projection_28 1.00 root Column#3, Column#6, Column#9, Column#12",
"Projection_28 1.00 root Column#9, Column#10, Column#11, Column#12",
"└─Apply_30 1.00 root CARTESIAN left outer join",
" ├─Apply_32(Build) 1.00 root CARTESIAN left outer join",
" │ ├─Apply_34(Build) 1.00 root CARTESIAN left outer join",
" │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#3, funcs:firstrow(Column#23)->test.test.a",
" │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#9, funcs:firstrow(Column#23)->test.test.a",
" │ │ │ └─TableReader_40 1.00 root data:HashAgg_35",
" │ │ │ └─HashAgg_35 1.00 cop[tikv] funcs:sum(test.test.a)->Column#22, funcs:firstrow(test.test.a)->Column#23",
" │ │ │ └─TableFullScan_38 10000.00 cop[tikv] table:tt keep order:false, stats:pseudo",
" │ │ └─Projection_43(Probe) 1.00 root <nil>->Column#6",
" │ │ └─Projection_43(Probe) 1.00 root <nil>->Column#10",
" │ │ └─Limit_44 1.00 root offset:0, count:1",
" │ │ └─TableReader_50 1.00 root data:Limit_49",
" │ │ └─Limit_49 1.00 cop[tikv] offset:0, count:1",
" │ │ └─Selection_48 1.00 cop[tikv] eq(test.test.a, test.test.a)",
" │ │ └─TableFullScan_47 1000.00 cop[tikv] table:test keep order:false, stats:pseudo",
" │ └─Projection_54(Probe) 1.00 root <nil>->Column#9",
" │ └─Projection_54(Probe) 1.00 root <nil>->Column#11",
" │ └─Limit_55 1.00 root offset:0, count:1",
" │ └─TableReader_61 1.00 root data:Limit_60",
" │ └─Limit_60 1.00 cop[tikv] offset:0, count:1",
Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/stringer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ func (s *testStringerSuite) TestGroupStringer(c *C) {
output[i].SQL = sql
output[i].Result = ToString(group)
})
c.Assert(ToString(group), DeepEquals, output[i].Result)
c.Assert(ToString(group), DeepEquals, output[i].Result, Commentf("case:%v, sql:%s", i, sql))
}
}
4 changes: 2 additions & 2 deletions planner/cascades/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -1019,10 +1019,10 @@
{
"SQL": "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2",
"Plan": [
"Projection_28 1.00 root Column#3, test.t1.a, test.t1.b",
"Projection_28 1.00 root Column#7, test.t1.a, test.t1.b",
"└─Apply_30 1.00 root CARTESIAN left outer join",
" ├─Apply_32(Build) 1.00 root CARTESIAN left outer join",
" │ ├─HashAgg_37(Build) 1.00 root funcs:sum(Column#8)->Column#3, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
" │ ├─HashAgg_37(Build) 1.00 root funcs:sum(Column#8)->Column#7, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
" │ │ └─TableReader_38 1.00 root data:HashAgg_39",
" │ │ └─HashAgg_39 1.00 cop[tikv] funcs:sum(test.t2.a)->Column#8, funcs:firstrow(test.t2.a)->Column#9, funcs:firstrow(test.t2.b)->Column#10",
" │ │ └─TableFullScan_35 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/testdata/stringer_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@
"SQL": "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1",
"Result": [
"Group#0 Schema:[Column#25]",
" Projection_2 input:[Group#1], eq(test.t.a, test.t.a)->Column#25",
" Projection_3 input:[Group#1], eq(test.t.a, test.t.a)->Column#25",
"Group#1 Schema:[test.t.a,test.t.b,test.t.a]",
" Apply_9 input:[Group#2,Group#3], left outer join",
"Group#2 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]",
Expand Down
19 changes: 16 additions & 3 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,24 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
if er.aggrMap != nil {
index, ok = er.aggrMap[v]
}
if !ok {
er.err = ErrInvalidGroupFuncUse
if ok {
// index < 0 indicates this is a correlated aggregate belonging to outer query,
// for which a correlated column will be created later, so we append a null constant
// as a temporary result expression.
if index < 0 {
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
} else {
// index >= 0 indicates this is a regular aggregate column
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
}
return inNode, true
}
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
// replace correlated aggregate in sub-query with its corresponding correlated column
if col, ok := er.b.correlatedAggMapper[v]; ok {
er.ctxStackAppend(col, types.EmptyName)
return inNode, true
}
er.err = ErrInvalidGroupFuncUse
return inNode, true
case *ast.ColumnNameExpr:
if index, ok := er.b.colMapper[v]; ok {
Expand Down
80 changes: 80 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,86 @@ func (s *testIntegrationSuite) TestUpdateMultiUpdatePK(c *C) {
tk.MustQuery("SELECT * FROM t").Check(testkit.Rows("2 12"))
}

// TestCorrelatedAggregate runs queries whose sub-queries contain aggregate
// functions over columns of an outer query, and checks the returned rows.
// The sub-queries are placed in the SELECT fields, HAVING and ORDER BY of the
// outer query; the aggregates are placed in different clauses of the sub-query.
// Statements share one session and later sections rely on the tables and
// session variables set up by earlier ones, so their order matters.
func (s *testIntegrationSuite) TestCorrelatedAggregate(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// Case tagged #18350 (presumably the issue number; verify against the tracker).
// tab holds 3 rows and tab2 holds 4 rows, so the join "tab m, tab2 n" yields 12 rows.
tk.MustExec("DROP TABLE IF EXISTS tab, tab2")
tk.MustExec("CREATE TABLE tab(i INT)")
tk.MustExec("CREATE TABLE tab2(j INT)")
tk.MustExec("insert into tab values(1),(2),(3)")
tk.MustExec("insert into tab2 values(1),(2),(3),(15)")
// With GROUP BY on m.i, COUNT(n.j) is expected to be 4 for each of the 3 groups.
tk.MustQuery(`SELECT m.i,
(SELECT COUNT(n.j)
FROM tab2 WHERE j=15) AS o
FROM tab m, tab2 n GROUP BY 1 order by m.i`).Check(testkit.Rows("1 4", "2 4", "3 4"))
// Without GROUP BY, a single row with the count over all 12 joined rows is expected.
tk.MustQuery(`SELECT
(SELECT COUNT(n.j)
FROM tab2 WHERE j=15) AS o
FROM tab m, tab2 n order by m.i`).Check(testkit.Rows("12"))

// Case tagged #17748 (presumably the issue number; verify against the tracker).
tk.MustExec("drop table if exists t1, t2")
tk.MustExec("create table t1 (a int, b int)")
tk.MustExec("create table t2 (m int, n int)")
tk.MustExec("insert into t1 values (2,2), (2,2), (3,3), (3,3), (3,3), (4,4)")
tk.MustExec("insert into t2 values (1,11), (2,22), (3,32), (4,44), (4,44)")
// NOTE(review): sql_mode is switched here and never restored in this function;
// it stays in effect for every statement below. Confirm which query requires it.
tk.MustExec("set @@sql_mode='TRADITIONAL'")

// count(a) inside the sub-query refers to t1 of the outer query.
// t2 contains two rows with m = 4, hence the expected "1,1" for the a = 4 group.
tk.MustQuery(`select count(*) c, a,
( select group_concat(count(a)) from t2 where m = a )
from t1 group by a order by a`).
Check(testkit.Rows("2 2 2", "3 3 3", "1 4 1,1"))

// Table used by all remaining cases: 6 rows, sum(a) = 14,
// 3 distinct values in column a and 3 distinct values in column b.
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1,1),(2,1),(2,2),(3,1),(3,2),(3,3)")

// Sub-queries in SELECT fields
// from SELECT fields
tk.MustQuery("select (select count(a)) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select (select (select count(a)))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select (select count(n.a)) from t m order by count(m.b)) from t n").Check(testkit.Rows("6"))
// from WHERE
// count(n.a) is 6 here, so the "=3" filter is expected to leave no row (NULL result).
tk.MustQuery("select (select count(n.a) from t where count(n.a)=3) from t n").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (select count(a) from t where count(distinct n.a)=3) from t n").Check(testkit.Rows("6"))
// from HAVING
tk.MustQuery("select (select count(n.a) from t having count(n.a)=6 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6"))
// count(distinct n.b) is 3, so the "=6" condition is expected to reject the row (NULL result).
tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=6 limit 1) from t n").Check(testkit.Rows("<nil>"))
// from ORDER BY
tk.MustQuery("select (select count(n.a) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(distinct n.b) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("3"))
// from TableRefsClause
tk.MustQuery("select (select cnt from (select count(a) cnt) s) from t").Check(testkit.Rows("6"))
// The derived table s has a single row, so count(cnt) over it is expected to be 1.
tk.MustQuery("select (select count(cnt) from (select count(a) cnt) s) from t").Check(testkit.Rows("1"))
// from sub-query inside aggregate
tk.MustQuery("select (select sum((select count(a)))) from t").Check(testkit.Rows("6"))
// Expected 20 = count(a) (6) + sum(a) (14).
tk.MustQuery("select (select sum((select count(a))+sum(a))) from t").Check(testkit.Rows("20"))
// from GROUP BY
tk.MustQuery("select (select count(a) from t group by count(n.a)) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(distinct a) from t group by count(n.a)) from t n").Check(testkit.Rows("3"))

// Sub-queries in HAVING
// count(a) is 6, so "= 0" is expected to return no rows and "> 0" the single sum(a) row.
tk.MustQuery("select sum(a) from t having (select count(a)) = 0").Check(testkit.Rows())
tk.MustQuery("select sum(a) from t having (select count(a)) > 0").Check(testkit.Rows("14"))

// Sub-queries in ORDER BY
// Group sizes by b are 3 (b=1), 2 (b=2) and 1 (b=3); the second query negates the sort key.
tk.MustQuery("select count(a) from t group by b order by (select count(a))").Check(testkit.Rows("1", "2", "3"))
tk.MustQuery("select count(a) from t group by b order by (select -count(a))").Check(testkit.Rows("3", "2", "1"))

// Nested aggregate (correlated aggregate inside aggregate)
tk.MustQuery("select (select sum(count(a))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(sum(a))) from t").Check(testkit.Rows("14"))

// Combining aggregates
// The same aggregate is used both directly in the outer query and inside sub-queries.
tk.MustQuery("select count(a), (select count(a)) from t").Check(testkit.Rows("6 6"))
tk.MustQuery("select sum(distinct b), count(a), (select count(a)), (select cnt from (select sum(distinct b) as cnt) n) from t").
Check(testkit.Rows("6 6 6 6"))
}

func (s *testIntegrationSuite) TestInvalidNamedWindowSpec(c *C) {
// #12356
tk := testkit.NewTestKit(c, s.store)
Expand Down
Loading