Skip to content

Commit

Permalink
planner: fix correlated aggregates which should be evaluated in outer…
Browse files Browse the repository at this point in the history
… query (#21431) (#21877)
  • Loading branch information
ti-srebot authored Jan 28, 2021
1 parent 8e54456 commit 80d420d
Show file tree
Hide file tree
Showing 16 changed files with 611 additions and 83 deletions.
Binary file removed cmd/explaintest/portgenerator
Binary file not shown.
75 changes: 72 additions & 3 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -652,9 +652,9 @@ drop table if exists t;
create table t(a int, b int, c int);
explain select * from (select * from t order by (select 2)) t order by a, b;
id estRows task access object operator info
Sort_12 10000.00 root test.t.a:asc, test.t.b:asc
└─TableReader_18 10000.00 root data:TableFullScan_17
└─TableFullScan_17 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
Sort_13 10000.00 root test.t.a:asc, test.t.b:asc
└─TableReader_19 10000.00 root data:TableFullScan_18
└─TableFullScan_18 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select * from (select * from t order by c) t order by a, b;
id estRows task access object operator info
Sort_6 10000.00 root test.t.a:asc, test.t.b:asc
Expand Down Expand Up @@ -784,3 +784,72 @@ Update_4 N/A root N/A
├─IndexRangeScan_9(Build) 0.10 cop[tikv] table:t, index:a(a, b) range:["[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb","[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb"], keep order:false, stats:pseudo
└─TableRowIDScan_10(Probe) 0.10 cop[tikv] table:t keep order:false, stats:pseudo
drop table if exists t;
create table t(a int, b int);
explain select (select count(n.a) from t) from t n;
id estRows task access object operator info
Projection_9 1.00 root Column#8
└─Apply_11 1.00 root CARTESIAN left outer join
├─StreamAgg_23(Build) 1.00 root funcs:count(Column#13)->Column#7
│ └─TableReader_24 1.00 root data:StreamAgg_15
│ └─StreamAgg_15 1.00 cop[tikv] funcs:count(test.t.a)->Column#13
│ └─TableFullScan_22 10000.00 cop[tikv] table:n keep order:false, stats:pseudo
└─MaxOneRow_27(Probe) 1.00 root
└─Projection_28 2.00 root Column#7
└─TableReader_30 2.00 root data:TableFullScan_29
└─TableFullScan_29 2.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select (select sum((select count(a)))) from t;
id estRows task access object operator info
Projection_23 1.00 root Column#7
└─Apply_25 1.00 root CARTESIAN left outer join
├─StreamAgg_37(Build) 1.00 root funcs:count(Column#15)->Column#5
│ └─TableReader_38 1.00 root data:StreamAgg_29
│ └─StreamAgg_29 1.00 cop[tikv] funcs:count(test.t.a)->Column#15
│ └─TableFullScan_36 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─HashAgg_43(Probe) 1.00 root funcs:sum(Column#12)->Column#7
└─HashJoin_44 1.00 root CARTESIAN left outer join
├─HashAgg_49(Build) 1.00 root group by:1, funcs:sum(Column#16)->Column#12
│ └─Projection_54 1.00 root cast(Column#6, decimal(65,0) BINARY)->Column#16
│ └─MaxOneRow_50 1.00 root
│ └─Projection_51 1.00 root Column#5
│ └─TableDual_52 1.00 root rows:1
└─TableDual_46(Probe) 1.00 root rows:1
explain select count(a) from t group by b order by (select count(a));
id estRows task access object operator info
Projection_11 8000.00 root Column#4
└─Sort_12 8000.00 root Column#5:asc
└─Apply_15 8000.00 root CARTESIAN left outer join
├─HashAgg_20(Build) 8000.00 root group by:test.t.b, funcs:count(Column#9)->Column#4
│ └─TableReader_21 8000.00 root data:HashAgg_16
│ └─HashAgg_16 8000.00 cop[tikv] group by:test.t.b, funcs:count(test.t.a)->Column#9
│ └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─MaxOneRow_24(Probe) 1.00 root
└─Projection_25 1.00 root Column#4
└─TableDual_26 1.00 root rows:1
explain select (select sum(count(a))) from t;
id estRows task access object operator info
Projection_11 1.00 root Column#5
└─Apply_13 1.00 root CARTESIAN left outer join
├─StreamAgg_25(Build) 1.00 root funcs:count(Column#8)->Column#4
│ └─TableReader_26 1.00 root data:StreamAgg_17
│ └─StreamAgg_17 1.00 cop[tikv] funcs:count(test.t.a)->Column#8
│ └─TableFullScan_24 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─StreamAgg_32(Probe) 1.00 root funcs:sum(Column#9)->Column#5
└─Projection_39 1.00 root cast(Column#4, decimal(65,0) BINARY)->Column#9
└─TableDual_37 1.00 root rows:1
explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a));
id estRows task access object operator info
Projection_16 8000.00 root Column#4, Column#6, Column#5
└─Sort_17 8000.00 root Column#7:asc
└─Apply_20 8000.00 root CARTESIAN left outer join
├─Apply_22(Build) 8000.00 root CARTESIAN left outer join
│ ├─HashAgg_27(Build) 8000.00 root group by:test.t.b, funcs:sum(Column#16)->Column#4, funcs:count(Column#17)->Column#5
│ │ └─TableReader_28 8000.00 root data:HashAgg_23
│ │ └─HashAgg_23 8000.00 cop[tikv] group by:test.t.b, funcs:sum(test.t.a)->Column#16, funcs:count(test.t.a)->Column#17
│ │ └─TableFullScan_26 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
│ └─MaxOneRow_31(Probe) 1.00 root
│ └─Projection_32 1.00 root Column#4
│ └─TableDual_33 1.00 root rows:1
└─MaxOneRow_34(Probe) 1.00 root
└─Projection_35 1.00 root Column#5
└─TableDual_36 1.00 root rows:1
drop table if exists t;
8 changes: 4 additions & 4 deletions cmd/explaintest/r/tpch.result
Original file line number Diff line number Diff line change
Expand Up @@ -711,10 +711,10 @@ and n_name = 'MOZAMBIQUE'
order by
value desc;
id estRows task access object operator info
Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#18
└─Sort_53 1304801.67 root Column#18:desc
└─Selection_55 1304801.67 root gt(Column#18, NULL)
└─HashAgg_58 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#18, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey
Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#35
└─Sort_53 1304801.67 root Column#35:desc
└─Selection_55 1304801.67 root gt(Column#35, NULL)
└─HashAgg_58 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#35, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey
└─Projection_79 1631002.09 root mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty, decimal(20,0) BINARY))->Column#42, tpch.partsupp.ps_partkey, tpch.partsupp.ps_partkey
└─HashJoin_62 1631002.09 root inner join, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)]
├─HashJoin_70(Build) 20000.00 root inner join, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)]
Expand Down
8 changes: 8 additions & 0 deletions cmd/explaintest/t/explain_easy.test
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,11 @@ create table t(a binary(16) not null, b varchar(2) default null, c varchar(100)
explain select * from t where a=x'FA34E1093CB428485734E3917F000000' and b='xb';
explain update t set c = 'ssss' where a=x'FA34E1093CB428485734E3917F000000' and b='xb';
drop table if exists t;

create table t(a int, b int);
explain select (select count(n.a) from t) from t n;
explain select (select sum((select count(a)))) from t;
explain select count(a) from t group by b order by (select count(a));
explain select (select sum(count(a))) from t;
explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a));
drop table if exists t;
8 changes: 4 additions & 4 deletions executor/testdata/agg_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@
"Name": "TestIssue12759HashAggCalledByApply",
"Cases": [
[
"Projection_28 1.00 root Column#3, Column#6, Column#9, Column#12",
"Projection_28 1.00 root Column#9, Column#10, Column#11, Column#12",
"└─Apply_30 1.00 root CARTESIAN left outer join",
" ├─Apply_32(Build) 1.00 root CARTESIAN left outer join",
" │ ├─Apply_34(Build) 1.00 root CARTESIAN left outer join",
" │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#3, funcs:firstrow(Column#23)->test.test.a",
" │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#9, funcs:firstrow(Column#23)->test.test.a",
" │ │ │ └─TableReader_40 1.00 root data:HashAgg_35",
" │ │ │ └─HashAgg_35 1.00 cop[tikv] funcs:sum(test.test.a)->Column#22, funcs:firstrow(test.test.a)->Column#23",
" │ │ │ └─TableFullScan_38 10000.00 cop[tikv] table:tt keep order:false, stats:pseudo",
" │ │ └─Projection_43(Probe) 1.00 root <nil>->Column#6",
" │ │ └─Projection_43(Probe) 1.00 root <nil>->Column#10",
" │ │ └─Limit_44 1.00 root offset:0, count:1",
" │ │ └─TableReader_50 1.00 root data:Limit_49",
" │ │ └─Limit_49 1.00 cop[tikv] offset:0, count:1",
" │ │ └─Selection_48 1.00 cop[tikv] eq(test.test.a, test.test.a)",
" │ │ └─TableFullScan_47 1000.00 cop[tikv] table:test keep order:false, stats:pseudo",
" │ └─Projection_54(Probe) 1.00 root <nil>->Column#9",
" │ └─Projection_54(Probe) 1.00 root <nil>->Column#11",
" │ └─Limit_55 1.00 root offset:0, count:1",
" │ └─TableReader_61 1.00 root data:Limit_60",
" │ └─Limit_60 1.00 cop[tikv] offset:0, count:1",
Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/stringer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ func (s *testStringerSuite) TestGroupStringer(c *C) {
output[i].SQL = sql
output[i].Result = ToString(group)
})
c.Assert(ToString(group), DeepEquals, output[i].Result)
c.Assert(ToString(group), DeepEquals, output[i].Result, Commentf("case:%v, sql:%s", i, sql))
}
}
4 changes: 2 additions & 2 deletions planner/cascades/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -1019,10 +1019,10 @@
{
"SQL": "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2",
"Plan": [
"Projection_28 1.00 root Column#3, test.t1.a, test.t1.b",
"Projection_28 1.00 root Column#7, test.t1.a, test.t1.b",
"└─Apply_30 1.00 root CARTESIAN left outer join",
" ├─Apply_32(Build) 1.00 root CARTESIAN left outer join",
" │ ├─HashAgg_37(Build) 1.00 root funcs:sum(Column#8)->Column#3, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
" │ ├─HashAgg_37(Build) 1.00 root funcs:sum(Column#8)->Column#7, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
" │ │ └─TableReader_38 1.00 root data:HashAgg_39",
" │ │ └─HashAgg_39 1.00 cop[tikv] funcs:sum(test.t2.a)->Column#8, funcs:firstrow(test.t2.a)->Column#9, funcs:firstrow(test.t2.b)->Column#10",
" │ │ └─TableFullScan_35 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/testdata/stringer_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@
"SQL": "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1",
"Result": [
"Group#0 Schema:[Column#25]",
" Projection_2 input:[Group#1], eq(test.t.a, test.t.a)->Column#25",
" Projection_3 input:[Group#1], eq(test.t.a, test.t.a)->Column#25",
"Group#1 Schema:[test.t.a,test.t.b,test.t.a]",
" Apply_9 input:[Group#2,Group#3], left outer join",
"Group#2 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]",
Expand Down
19 changes: 16 additions & 3 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,24 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
if er.aggrMap != nil {
index, ok = er.aggrMap[v]
}
if !ok {
er.err = ErrInvalidGroupFuncUse
if ok {
// index < 0 indicates this is a correlated aggregate belonging to outer query,
// for which a correlated column will be created later, so we append a null constant
// as a temporary result expression.
if index < 0 {
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
} else {
// index >= 0 indicates this is a regular aggregate column
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
}
return inNode, true
}
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
// replace correlated aggregate in sub-query with its corresponding correlated column
if col, ok := er.b.correlatedAggMapper[v]; ok {
er.ctxStackAppend(col, types.EmptyName)
return inNode, true
}
er.err = ErrInvalidGroupFuncUse
return inNode, true
case *ast.ColumnNameExpr:
if index, ok := er.b.colMapper[v]; ok {
Expand Down
80 changes: 80 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,86 @@ func (s *testIntegrationSuite) TestUpdateMultiUpdatePK(c *C) {
tk.MustQuery("SELECT * FROM t").Check(testkit.Rows("2 12"))
}

func (s *testIntegrationSuite) TestCorrelatedAggregate(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// #18350
tk.MustExec("DROP TABLE IF EXISTS tab, tab2")
tk.MustExec("CREATE TABLE tab(i INT)")
tk.MustExec("CREATE TABLE tab2(j INT)")
tk.MustExec("insert into tab values(1),(2),(3)")
tk.MustExec("insert into tab2 values(1),(2),(3),(15)")
tk.MustQuery(`SELECT m.i,
(SELECT COUNT(n.j)
FROM tab2 WHERE j=15) AS o
FROM tab m, tab2 n GROUP BY 1 order by m.i`).Check(testkit.Rows("1 4", "2 4", "3 4"))
tk.MustQuery(`SELECT
(SELECT COUNT(n.j)
FROM tab2 WHERE j=15) AS o
FROM tab m, tab2 n order by m.i`).Check(testkit.Rows("12"))

// #17748
tk.MustExec("drop table if exists t1, t2")
tk.MustExec("create table t1 (a int, b int)")
tk.MustExec("create table t2 (m int, n int)")
tk.MustExec("insert into t1 values (2,2), (2,2), (3,3), (3,3), (3,3), (4,4)")
tk.MustExec("insert into t2 values (1,11), (2,22), (3,32), (4,44), (4,44)")
tk.MustExec("set @@sql_mode='TRADITIONAL'")

tk.MustQuery(`select count(*) c, a,
( select group_concat(count(a)) from t2 where m = a )
from t1 group by a order by a`).
Check(testkit.Rows("2 2 2", "3 3 3", "1 4 1,1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1,1),(2,1),(2,2),(3,1),(3,2),(3,3)")

// Sub-queries in SELECT fields
// from SELECT fields
tk.MustQuery("select (select count(a)) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select (select (select count(a)))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select (select count(n.a)) from t m order by count(m.b)) from t n").Check(testkit.Rows("6"))
// from WHERE
tk.MustQuery("select (select count(n.a) from t where count(n.a)=3) from t n").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (select count(a) from t where count(distinct n.a)=3) from t n").Check(testkit.Rows("6"))
// from HAVING
tk.MustQuery("select (select count(n.a) from t having count(n.a)=6 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=6 limit 1) from t n").Check(testkit.Rows("<nil>"))
// from ORDER BY
tk.MustQuery("select (select count(n.a) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(distinct n.b) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("3"))
// from TableRefsClause
tk.MustQuery("select (select cnt from (select count(a) cnt) s) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(cnt) from (select count(a) cnt) s) from t").Check(testkit.Rows("1"))
// from sub-query inside aggregate
tk.MustQuery("select (select sum((select count(a)))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum((select count(a))+sum(a))) from t").Check(testkit.Rows("20"))
// from GROUP BY
tk.MustQuery("select (select count(a) from t group by count(n.a)) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(distinct a) from t group by count(n.a)) from t n").Check(testkit.Rows("3"))

// Sub-queries in HAVING
tk.MustQuery("select sum(a) from t having (select count(a)) = 0").Check(testkit.Rows())
tk.MustQuery("select sum(a) from t having (select count(a)) > 0").Check(testkit.Rows("14"))

// Sub-queries in ORDER BY
tk.MustQuery("select count(a) from t group by b order by (select count(a))").Check(testkit.Rows("1", "2", "3"))
tk.MustQuery("select count(a) from t group by b order by (select -count(a))").Check(testkit.Rows("3", "2", "1"))

// Nested aggregate (correlated aggregate inside aggregate)
tk.MustQuery("select (select sum(count(a))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(sum(a))) from t").Check(testkit.Rows("14"))

// Combining aggregates
tk.MustQuery("select count(a), (select count(a)) from t").Check(testkit.Rows("6 6"))
tk.MustQuery("select sum(distinct b), count(a), (select count(a)), (select cnt from (select sum(distinct b) as cnt) n) from t").
Check(testkit.Rows("6 6 6 6"))
}

func (s *testIntegrationSuite) TestInvalidNamedWindowSpec(c *C) {
// #12356
tk := testkit.NewTestKit(c, s.store)
Expand Down
Loading

0 comments on commit 80d420d

Please sign in to comment.