diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index f2970e9916ab2..c9dac04b40f74 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -652,9 +652,15 @@ drop table if exists t; create table t(a int, b int, c int); explain select * from (select * from t order by (select 2)) t order by a, b; id estRows task access object operator info +<<<<<<< HEAD Sort_12 10000.00 root test.t.a:asc, test.t.b:asc └─TableReader_18 10000.00 root data:TableFullScan_17 └─TableFullScan_17 10000.00 cop[tikv] table:t keep order:false, stats:pseudo +======= +Sort_13 10000.00 root test.t.a, test.t.b +└─TableReader_19 10000.00 root data:TableFullScan_18 + └─TableFullScan_18 10000.00 cop[tikv] table:t keep order:false, stats:pseudo +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) explain select * from (select * from t order by c) t order by a, b; id estRows task access object operator info Sort_6 10000.00 root test.t.a:asc, test.t.b:asc @@ -784,3 +790,65 @@ Update_4 N/A root N/A ├─IndexRangeScan_9(Build) 0.10 cop[tikv] table:t, index:a(a, b) range:["[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb","[250 52 225 9 60 180 40 72 87 52 227 145 127 0 0 0]" "xb"], keep order:false, stats:pseudo └─TableRowIDScan_10(Probe) 0.10 cop[tikv] table:t keep order:false, stats:pseudo drop table if exists t; +create table t(a int, b int); +explain select (select count(n.a) from t) from t n; +id estRows task access object operator info +Projection_9 1.00 root Column#8 +└─Apply_11 1.00 root CARTESIAN left outer join + ├─StreamAgg_23(Build) 1.00 root funcs:count(Column#13)->Column#7 + │ └─TableReader_24 1.00 root data:StreamAgg_15 + │ └─StreamAgg_15 1.00 cop[tikv] funcs:count(test.t.a)->Column#13 + │ └─TableFullScan_22 10000.00 cop[tikv] table:n keep order:false, stats:pseudo + └─MaxOneRow_27(Probe) 1.00 root + └─Projection_28 2.00 root Column#7 + └─TableReader_30 2.00 root data:TableFullScan_29 + └─TableFullScan_29 2.00 cop[tikv] table:t keep order:false, stats:pseudo +explain select (select sum((select count(a)))) from t; +id estRows task access object operator info +Projection_23 1.00 root Column#7 +└─Apply_25 1.00 root CARTESIAN left outer join + ├─StreamAgg_37(Build) 1.00 root funcs:count(Column#15)->Column#5 + │ └─TableReader_38 1.00 root data:StreamAgg_29 + │ └─StreamAgg_29 1.00 cop[tikv] funcs:count(test.t.a)->Column#15 + │ └─TableFullScan_36 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + └─HashAgg_43(Probe) 1.00 root funcs:sum(Column#12)->Column#7 + └─HashJoin_44 1.00 root CARTESIAN left outer join + ├─HashAgg_49(Build) 1.00 root group by:1, funcs:sum(Column#16)->Column#12 + │ └─Projection_54 1.00 root cast(Column#6, decimal(42,0) BINARY)->Column#16 + │ └─MaxOneRow_50 1.00 root + │ └─Projection_51 1.00 root Column#5 + │ └─TableDual_52 1.00 root rows:1 + └─TableDual_46(Probe) 1.00 root rows:1 +explain select count(a) from t group by b order by (select count(a)); +id estRows task access object operator info +Sort_12 8000.00 root Column#4 +└─HashJoin_14 8000.00 root CARTESIAN left outer join + ├─TableDual_24(Build) 1.00 root rows:1 + └─HashAgg_20(Probe) 8000.00 root group by:test.t.b, funcs:count(Column#8)->Column#4 + └─TableReader_21 8000.00 root data:HashAgg_16 + └─HashAgg_16 8000.00 cop[tikv] group by:test.t.b, funcs:count(test.t.a)->Column#8 + └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo +explain select (select sum(count(a))) from 
t; +id estRows task access object operator info +Projection_11 1.00 root Column#5 +└─Apply_13 1.00 root CARTESIAN left outer join + ├─StreamAgg_25(Build) 1.00 root funcs:count(Column#8)->Column#4 + │ └─TableReader_26 1.00 root data:StreamAgg_17 + │ └─StreamAgg_17 1.00 cop[tikv] funcs:count(test.t.a)->Column#8 + │ └─TableFullScan_24 10000.00 cop[tikv] table:t keep order:false, stats:pseudo + └─StreamAgg_32(Probe) 1.00 root funcs:sum(Column#9)->Column#5 + └─Projection_39 1.00 root cast(Column#4, decimal(42,0) BINARY)->Column#9 + └─TableDual_37 1.00 root rows:1 +explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a)); +id estRows task access object operator info +Projection_16 8000.00 root Column#4, Column#4, Column#5 +└─Sort_17 8000.00 root Column#5 + └─HashJoin_19 8000.00 root CARTESIAN left outer join + ├─TableDual_33(Build) 1.00 root rows:1 + └─HashJoin_21(Probe) 8000.00 root CARTESIAN left outer join + ├─TableDual_31(Build) 1.00 root rows:1 + └─HashAgg_27(Probe) 8000.00 root group by:test.t.b, funcs:sum(Column#13)->Column#4, funcs:count(Column#14)->Column#5 + └─TableReader_28 8000.00 root data:HashAgg_23 + └─HashAgg_23 8000.00 cop[tikv] group by:test.t.b, funcs:sum(test.t.a)->Column#13, funcs:count(test.t.a)->Column#14 + └─TableFullScan_26 10000.00 cop[tikv] table:t keep order:false, stats:pseudo +drop table if exists t; diff --git a/cmd/explaintest/r/tpch.result b/cmd/explaintest/r/tpch.result index 231236d380b90..8751e16e815c2 100644 --- a/cmd/explaintest/r/tpch.result +++ b/cmd/explaintest/r/tpch.result @@ -711,6 +711,7 @@ and n_name = 'MOZAMBIQUE' order by value desc; id estRows task access object operator info +<<<<<<< HEAD Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#18 └─Sort_53 1304801.67 root Column#18:desc └─Selection_55 1304801.67 root gt(Column#18, NULL) @@ -725,6 +726,22 @@ Projection_52 1304801.67 root tpch.partsupp.ps_partkey, Column#18 │ └─TableFullScan_71 500000.00 cop[tikv] table:supplier keep order:false └─TableReader_77(Probe) 40000000.00 root data:TableFullScan_76 └─TableFullScan_76 40000000.00 cop[tikv] table:partsupp keep order:false +======= +Projection_57 1304801.67 root tpch.partsupp.ps_partkey, Column#35 +└─Sort_58 1304801.67 root Column#35:desc + └─Selection_60 1304801.67 root gt(Column#35, NULL) + └─HashAgg_63 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#35, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey + └─Projection_89 1631002.09 root mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty, decimal(20,0) BINARY))->Column#42, tpch.partsupp.ps_partkey, tpch.partsupp.ps_partkey + └─HashJoin_67 1631002.09 root inner join, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] + ├─HashJoin_80(Build) 20000.00 root inner join, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] + │ ├─TableReader_85(Build) 1.00 root data:Selection_84 + │ │ └─Selection_84 1.00 cop[tikv] eq(tpch.nation.n_name, "MOZAMBIQUE") + │ │ └─TableFullScan_83 25.00 cop[tikv] table:nation keep order:false + │ └─TableReader_82(Probe) 500000.00 root data:TableFullScan_81 + │ └─TableFullScan_81 500000.00 cop[tikv] table:supplier keep order:false + └─TableReader_87(Probe) 40000000.00 root data:TableFullScan_86 + └─TableFullScan_86 40000000.00 cop[tikv] table:partsupp keep order:false +>>>>>>> f687ebd91... 
planner: fix correlated aggregates which should be evaluated in outer query (#21431) /* Q12 Shipping Modes and Order Priority Query This query determines whether selecting less expensive modes of shipping is negatively affecting the critical-priority diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index 930f097d9d2b3..a01179dd3ddbe 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -210,3 +210,11 @@ create table t(a binary(16) not null, b varchar(2) default null, c varchar(100) explain select * from t where a=x'FA34E1093CB428485734E3917F000000' and b='xb'; explain update t set c = 'ssss' where a=x'FA34E1093CB428485734E3917F000000' and b='xb'; drop table if exists t; + +create table t(a int, b int); +explain select (select count(n.a) from t) from t n; +explain select (select sum((select count(a)))) from t; +explain select count(a) from t group by b order by (select count(a)); +explain select (select sum(count(a))) from t; +explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a)); +drop table if exists t; diff --git a/executor/testdata/agg_suite_out.json b/executor/testdata/agg_suite_out.json index ab65db16fb991..73d920dbbbb3c 100644 --- a/executor/testdata/agg_suite_out.json +++ b/executor/testdata/agg_suite_out.json @@ -49,21 +49,21 @@ "Name": "TestIssue12759HashAggCalledByApply", "Cases": [ [ - "Projection_28 1.00 root Column#3, Column#6, Column#9, Column#12", + "Projection_28 1.00 root Column#9, Column#10, Column#11, Column#12", "└─Apply_30 1.00 root CARTESIAN left outer join", " ├─Apply_32(Build) 1.00 root CARTESIAN left outer join", " │ ├─Apply_34(Build) 1.00 root CARTESIAN left outer join", - " │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#3, funcs:firstrow(Column#23)->test.test.a", + " │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#9, funcs:firstrow(Column#23)->test.test.a", " │ │ │ └─TableReader_40 1.00 root data:HashAgg_35", " │ │ │ └─HashAgg_35 1.00 cop[tikv] funcs:sum(test.test.a)->Column#22, funcs:firstrow(test.test.a)->Column#23", " │ │ │ └─TableFullScan_38 10000.00 cop[tikv] table:tt keep order:false, stats:pseudo", - " │ │ └─Projection_43(Probe) 1.00 root ->Column#6", + " │ │ └─Projection_43(Probe) 1.00 root ->Column#10", " │ │ └─Limit_44 1.00 root offset:0, count:1", " │ │ └─TableReader_50 1.00 root data:Limit_49", " │ │ └─Limit_49 1.00 cop[tikv] offset:0, count:1", " │ │ └─Selection_48 1.00 cop[tikv] eq(test.test.a, test.test.a)", " │ │ └─TableFullScan_47 1000.00 cop[tikv] table:test keep order:false, stats:pseudo", - " │ └─Projection_54(Probe) 1.00 root ->Column#9", + " │ └─Projection_54(Probe) 1.00 root ->Column#11", " │ └─Limit_55 1.00 root offset:0, count:1", " │ └─TableReader_61 1.00 root data:Limit_60", " │ └─Limit_60 1.00 cop[tikv] offset:0, count:1", diff --git a/planner/cascades/stringer_test.go b/planner/cascades/stringer_test.go index 85314983548c9..c9a64772d8145 100644 --- a/planner/cascades/stringer_test.go +++ b/planner/cascades/stringer_test.go @@ -86,6 +86,6 @@ func (s *testStringerSuite) TestGroupStringer(c *C) { output[i].SQL = sql output[i].Result = ToString(group) }) - c.Assert(ToString(group), DeepEquals, output[i].Result) + c.Assert(ToString(group), DeepEquals, output[i].Result, Commentf("case:%v, sql:%s", i, sql)) } } diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index c3df6622d4dcb..ea64acb66b7be 100644 --- 
a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -1019,6 +1019,7 @@ { "SQL": "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2", "Plan": [ +<<<<<<< HEAD "Projection_28 1.00 root Column#3, test.t1.a, test.t1.b", "└─Apply_30 1.00 root CARTESIAN left outer join", " ├─Apply_32(Build) 1.00 root CARTESIAN left outer join", @@ -1038,6 +1039,27 @@ " └─Limit_49 1.00 cop[tikv] offset:0, count:1", " └─Selection_50 1.00 cop[tikv] eq(test.t1.b, test.t2.b)", " └─TableFullScan_51 1.00 cop[tikv] table:t1 keep order:false, stats:pseudo" +======= + "Projection_30 1.00 root Column#7, test.t1.a, test.t1.b", + "└─Apply_32 1.00 root CARTESIAN left outer join", + " ├─Apply_34(Build) 1.00 root CARTESIAN left outer join", + " │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#8)->Column#7, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b", + " │ │ └─TableReader_40 1.00 root data:HashAgg_41", + " │ │ └─HashAgg_41 1.00 cop[tikv] funcs:sum(test.t2.a)->Column#8, funcs:firstrow(test.t2.a)->Column#9, funcs:firstrow(test.t2.b)->Column#10", + " │ │ └─TableFullScan_38 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + " │ └─MaxOneRow_42(Probe) 1.00 root ", + " │ └─Limit_43 1.00 root offset:0, count:1", + " │ └─TableReader_44 1.00 root data:Limit_45", + " │ └─Limit_45 1.00 cop[tikv] offset:0, count:1", + " │ └─Selection_46 1.00 cop[tikv] eq(test.t1.a, test.t2.a)", + " │ └─TableFullScan_47 1.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─MaxOneRow_48(Probe) 1.00 root ", + " └─Limit_49 1.00 root offset:0, count:1", + " └─TableReader_50 1.00 root data:Limit_51", + " └─Limit_51 1.00 cop[tikv] offset:0, count:1", + " └─Selection_52 1.00 cop[tikv] eq(test.t1.b, test.t2.b)", + " └─TableFullScan_53 1.00 cop[tikv] table:t1 keep order:false, stats:pseudo" +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) ], "Result": [ "6 1 11" diff --git a/planner/cascades/testdata/stringer_suite_out.json b/planner/cascades/testdata/stringer_suite_out.json index 61bd297511785..dbcc00aec395f 100644 --- a/planner/cascades/testdata/stringer_suite_out.json +++ b/planner/cascades/testdata/stringer_suite_out.json @@ -290,7 +290,7 @@ "SQL": "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1", "Result": [ "Group#0 Schema:[Column#25]", - " Projection_2 input:[Group#1], eq(test.t.a, test.t.a)->Column#25", + " Projection_3 input:[Group#1], eq(test.t.a, test.t.a)->Column#25", "Group#1 Schema:[test.t.a,test.t.b,test.t.a]", " Apply_9 input:[Group#2,Group#3], left outer join", "Group#2 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]", diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 5ae9a20ebd4aa..6edd702fa8a8b 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -328,11 +328,24 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { if er.aggrMap != nil { index, ok = er.aggrMap[v] } - if !ok { - er.err = ErrInvalidGroupFuncUse + if ok { + // index < 0 indicates this is a correlated aggregate belonging to outer query, + // for which a correlated column will be created later, so we append a null constant + // as a temporary result expression. 
+ if index < 0 { + er.ctxStackAppend(expression.NewNull(), types.EmptyName) + } else { + // index >= 0 indicates this is a regular aggregate column + er.ctxStackAppend(er.schema.Columns[index], er.names[index]) + } return inNode, true } - er.ctxStackAppend(er.schema.Columns[index], er.names[index]) + // replace correlated aggregate in sub-query with its corresponding correlated column + if col, ok := er.b.correlatedAggMapper[v]; ok { + er.ctxStackAppend(col, types.EmptyName) + return inNode, true + } + er.err = ErrInvalidGroupFuncUse return inNode, true case *ast.ColumnNameExpr: if index, ok := er.b.colMapper[v]; ok { diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 5be43dcaa8779..1122d314d89d8 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -1612,6 +1612,219 @@ func (s *testIntegrationSuite) TestUpdateMultiUpdatePK(c *C) { tk.MustExec(`UPDATE t m, t n SET m.b = m.b + 10, n.b = n.b + 10`) tk.MustQuery("SELECT * FROM t").Check(testkit.Rows("1 12")) +<<<<<<< HEAD tk.MustExec(`UPDATE t m, t n SET m.a = m.a + 1, n.b = n.b + 10`) tk.MustQuery("SELECT * FROM t").Check(testkit.Rows("2 12")) +======= + tk.MustGetErrMsg(`UPDATE t m, t n SET m.a = m.a + 1, n.b = n.b + 10`, + `[planner:1706]Primary key/partition key update is not allowed since the table is updated both as 'm' and 'n'.`) + tk.MustGetErrMsg(`UPDATE t m, t n, t q SET m.a = m.a + 1, n.b = n.b + 10, q.b = q.b - 10`, + `[planner:1706]Primary key/partition key update is not allowed since the table is updated both as 'm' and 'n'.`) + tk.MustGetErrMsg(`UPDATE t m, t n, t q SET m.b = m.b + 1, n.a = n.a + 10, q.b = q.b - 10`, + `[planner:1706]Primary key/partition key update is not allowed since the table is updated both as 'm' and 'n'.`) + tk.MustGetErrMsg(`UPDATE t m, t n, t q SET m.b = m.b + 1, n.b = n.b + 10, q.a = q.a - 10`, + `[planner:1706]Primary key/partition key update is not allowed since the table is updated both as 'm' and 'q'.`) + tk.MustGetErrMsg(`UPDATE t q, t n, t m SET m.b = m.b + 1, n.b = n.b + 10, q.a = q.a - 10`, + `[planner:1706]Primary key/partition key update is not allowed since the table is updated both as 'q' and 'n'.`) + + tk.MustExec("update t m, t n set m.a = n.a+10 where m.a=n.a") + tk.MustQuery("select * from t").Check(testkit.Rows("11 12")) +} + +func (s *testIntegrationSuite) TestOrderByHavingNotInSelect(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists ttest") + tk.MustExec("create table ttest (v1 int, v2 int)") + tk.MustExec("insert into ttest values(1, 2), (4,6), (1, 7)") + tk.MustGetErrMsg("select v1 from ttest order by count(v2)", + "[planner:3029]Expression #1 of ORDER BY contains aggregate function and applies to the result of a non-aggregated query") + tk.MustGetErrMsg("select v1 from ttest having count(v2)", + "[planner:8123]In aggregated query without GROUP BY, expression #1 of SELECT list contains nonaggregated column 'v1'; this is incompatible with sql_mode=only_full_group_by") +} + +func (s *testIntegrationSuite) TestUpdateSetDefault(c *C) { + // #20598 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table tt (x int, z int as (x+10) stored)") + tk.MustExec("insert into tt(x) values (1)") + tk.MustExec("update tt set x=2, z = default") + tk.MustQuery("select * from tt").Check(testkit.Rows("2 12")) + + tk.MustGetErrMsg("update tt set z = 123", + "[planner:3105]The value specified for generated column 'z' in table 'tt' is not 
allowed.") + tk.MustGetErrMsg("update tt as ss set z = 123", + "[planner:3105]The value specified for generated column 'z' in table 'tt' is not allowed.") + tk.MustGetErrMsg("update tt as ss set x = 3, z = 13", + "[planner:3105]The value specified for generated column 'z' in table 'tt' is not allowed.") + tk.MustGetErrMsg("update tt as s1, tt as s2 set s1.z = default, s2.z = 456", + "[planner:3105]The value specified for generated column 'z' in table 'tt' is not allowed.") +} + +func (s *testIntegrationSuite) TestOrderByNotInSelectDistinct(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // #12442 + tk.MustExec("drop table if exists ttest") + tk.MustExec("create table ttest (v1 int, v2 int)") + tk.MustExec("insert into ttest values(1, 2), (4,6), (1, 7)") + + tk.MustGetErrMsg("select distinct v1 from ttest order by v2", + "[planner:3065]Expression #1 of ORDER BY clause is not in SELECT list, references column 'test.ttest.v2' which is not in SELECT list; this is incompatible with DISTINCT") + tk.MustGetErrMsg("select distinct v1+1 from ttest order by v1", + "[planner:3065]Expression #1 of ORDER BY clause is not in SELECT list, references column 'test.ttest.v1' which is not in SELECT list; this is incompatible with DISTINCT") + tk.MustGetErrMsg("select distinct v1+1 from ttest order by 1+v1", + "[planner:3065]Expression #1 of ORDER BY clause is not in SELECT list, references column 'test.ttest.v1' which is not in SELECT list; this is incompatible with DISTINCT") + tk.MustGetErrMsg("select distinct v1+1 from ttest order by v1+2", + "[planner:3065]Expression #1 of ORDER BY clause is not in SELECT list, references column 'test.ttest.v1' which is not in SELECT list; this is incompatible with DISTINCT") + tk.MustGetErrMsg("select distinct count(v1) from ttest group by v2 order by sum(v1)", + "[planner:3066]Expression #1 of ORDER BY clause is not in SELECT list, contains aggregate function; this is incompatible with DISTINCT") + tk.MustGetErrMsg("select distinct sum(v1)+1 from ttest group by v2 order by sum(v1)", + "[planner:3066]Expression #1 of ORDER BY clause is not in SELECT list, contains aggregate function; this is incompatible with DISTINCT") + + // Expressions in ORDER BY whole match some fields in DISTINCT. + tk.MustQuery("select distinct v1+1 from ttest order by v1+1").Check(testkit.Rows("2", "5")) + tk.MustQuery("select distinct count(v1) from ttest order by count(v1)").Check(testkit.Rows("3")) + tk.MustQuery("select distinct count(v1) from ttest group by v2 order by count(v1)").Check(testkit.Rows("1")) + tk.MustQuery("select distinct sum(v1) from ttest group by v2 order by sum(v1)").Check(testkit.Rows("1", "4")) + tk.MustQuery("select distinct v1, v2 from ttest order by 1, 2").Check(testkit.Rows("1 2", "1 7", "4 6")) + tk.MustQuery("select distinct v1, v2 from ttest order by 2, 1").Check(testkit.Rows("1 2", "4 6", "1 7")) + + // Referenced columns of expressions in ORDER BY whole match some fields in DISTINCT, + // both original expression and alias can be referenced. 
+ tk.MustQuery("select distinct v1 from ttest order by v1+1").Check(testkit.Rows("1", "4")) + tk.MustQuery("select distinct v1, v2 from ttest order by v1+1, v2").Check(testkit.Rows("1 2", "1 7", "4 6")) + tk.MustQuery("select distinct v1+1 as z, v2 from ttest order by v1+1, z+v2").Check(testkit.Rows("2 2", "2 7", "5 6")) + tk.MustQuery("select distinct sum(v1) as z from ttest group by v2 order by z+1").Check(testkit.Rows("1", "4")) + tk.MustQuery("select distinct sum(v1)+1 from ttest group by v2 order by sum(v1)+1").Check(testkit.Rows("2", "5")) + tk.MustQuery("select distinct v1 as z from ttest order by v1+z").Check(testkit.Rows("1", "4")) +} + +func (s *testIntegrationSuite) TestCorrelatedAggregate(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // #18350 + tk.MustExec("DROP TABLE IF EXISTS tab, tab2") + tk.MustExec("CREATE TABLE tab(i INT)") + tk.MustExec("CREATE TABLE tab2(j INT)") + tk.MustExec("insert into tab values(1),(2),(3)") + tk.MustExec("insert into tab2 values(1),(2),(3),(15)") + tk.MustQuery(`SELECT m.i, + (SELECT COUNT(n.j) + FROM tab2 WHERE j=15) AS o + FROM tab m, tab2 n GROUP BY 1 order by m.i`).Check(testkit.Rows("1 4", "2 4", "3 4")) + tk.MustQuery(`SELECT + (SELECT COUNT(n.j) + FROM tab2 WHERE j=15) AS o + FROM tab m, tab2 n order by m.i`).Check(testkit.Rows("12")) + + // #17748 + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int, b int)") + tk.MustExec("create table t2 (m int, n int)") + tk.MustExec("insert into t1 values (2,2), (2,2), (3,3), (3,3), (3,3), (4,4)") + tk.MustExec("insert into t2 values (1,11), (2,22), (3,32), (4,44), (4,44)") + tk.MustExec("set @@sql_mode='TRADITIONAL'") + + tk.MustQuery(`select count(*) c, a, + ( select group_concat(count(a)) from t2 where m = a ) + from t1 group by a order by a`). 
+ Check(testkit.Rows("2 2 2", "3 3 3", "1 4 1,1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("insert into t values (1,1),(2,1),(2,2),(3,1),(3,2),(3,3)") + + // Sub-queries in SELECT fields + // from SELECT fields + tk.MustQuery("select (select count(a)) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select (select (select count(a)))) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select (select count(n.a)) from t m order by count(m.b)) from t n").Check(testkit.Rows("6")) + // from WHERE + tk.MustQuery("select (select count(n.a) from t where count(n.a)=3) from t n").Check(testkit.Rows("")) + tk.MustQuery("select (select count(a) from t where count(distinct n.a)=3) from t n").Check(testkit.Rows("6")) + // from HAVING + tk.MustQuery("select (select count(n.a) from t having count(n.a)=6 limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=6 limit 1) from t n").Check(testkit.Rows("")) + // from ORDER BY + tk.MustQuery("select (select count(n.a) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(distinct n.b) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("3")) + // from TableRefsClause + tk.MustQuery("select (select cnt from (select count(a) cnt) s) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(cnt) from (select count(a) cnt) s) from t").Check(testkit.Rows("1")) + // from sub-query inside aggregate + tk.MustQuery("select (select sum((select count(a)))) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum((select count(a))+sum(a))) from t").Check(testkit.Rows("20")) + // from GROUP BY + tk.MustQuery("select (select count(a) from t group by count(n.a)) from t n").Check(testkit.Rows("6")) + tk.MustQuery("select (select count(distinct a) from t group by count(n.a)) from t n").Check(testkit.Rows("3")) + + // Sub-queries in HAVING + tk.MustQuery("select sum(a) from t having (select count(a)) = 0").Check(testkit.Rows()) + tk.MustQuery("select sum(a) from t having (select count(a)) > 0").Check(testkit.Rows("14")) + + // Sub-queries in ORDER BY + tk.MustQuery("select count(a) from t group by b order by (select count(a))").Check(testkit.Rows("1", "2", "3")) + tk.MustQuery("select count(a) from t group by b order by (select -count(a))").Check(testkit.Rows("3", "2", "1")) + + // Nested aggregate (correlated aggregate inside aggregate) + tk.MustQuery("select (select sum(count(a))) from t").Check(testkit.Rows("6")) + tk.MustQuery("select (select sum(sum(a))) from t").Check(testkit.Rows("14")) + + // Combining aggregates + tk.MustQuery("select count(a), (select count(a)) from t").Check(testkit.Rows("6 6")) + tk.MustQuery("select sum(distinct b), count(a), (select count(a)), (select cnt from (select sum(distinct b) as cnt) n) from t"). 
+ Check(testkit.Rows("6 6 6 6")) +} + +func (s *testIntegrationSuite) TestCorrelatedColumnAggFuncPushDown(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a int, b int);") + tk.MustExec("insert into t values (1,1);") + tk.MustQuery("select (select count(n.a + a) from t) from t n;").Check(testkit.Rows( + "1", + )) +} + +// Test for issue https://github.com/pingcap/tidb/issues/21607. +func (s *testIntegrationSuite) TestConditionColPruneInPhysicalUnionScan(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a int, b int);") + tk.MustExec("begin;") + tk.MustExec("insert into t values (1, 2);") + tk.MustQuery("select count(*) from t where b = 1 and b in (3);"). + Check(testkit.Rows("0")) + + tk.MustExec("drop table t;") + tk.MustExec("create table t (a int, b int as (a + 1), c int as (b + 1));") + tk.MustExec("begin;") + tk.MustExec("insert into t (a) values (1);") + tk.MustQuery("select count(*) from t where b = 1 and b in (3);"). + Check(testkit.Rows("0")) + tk.MustQuery("select count(*) from t where c = 1 and c in (3);"). + Check(testkit.Rows("0")) +} + +// Test for issue https://github.com/pingcap/tidb/issues/18320 +func (s *testIntegrationSuite) TestNonaggregateColumnWithSingleValueInOnlyFullGroupByMode(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int, c int)") + tk.MustExec("insert into t values (1, 2, 3), (4, 5, 6), (7, 8, 9)") + tk.MustQuery("select a, count(b) from t where a = 1").Check(testkit.Rows("1 1")) + tk.MustQuery("select a, count(b) from t where a = 10").Check(testkit.Rows(" 0")) + tk.MustQuery("select a, c, sum(b) from t where a = 1 group by c").Check(testkit.Rows("1 3 2")) + tk.MustGetErrMsg("select a from t where a = 1 order by count(b)", "[planner:3029]Expression #1 of ORDER BY contains aggregate function and applies to the result of a non-aggregated query") + tk.MustQuery("select a from t where a = 1 having count(b) > 0").Check(testkit.Rows("1")) +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 5d397e6fd8abb..5223491308ae7 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -172,7 +172,8 @@ func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, true } -func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) { +func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression, + correlatedAggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, map[int]int, error) { b.optFlag |= flagBuildKeyInfo b.optFlag |= flagPushDownAgg // We may apply aggregation eliminate optimization. 
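
Side note on the rule this patch relies on (a standalone sketch, not code from the patch): an aggregate found inside a sub-query is treated as a correlated aggregate of the outer query only when its arguments reference correlated (outer) columns and no columns from the sub-query's own FROM clause, which is the len(corCols) > 0 && len(cols) == 0 check in extractCorrelatedAggFuncs further down. The types and column names below are simplified stand-ins, not TiDB's ast/expression types.

package main

import "fmt"

// aggCall is a toy model of an aggregate call and where its argument columns resolve.
type aggCall struct {
	text      string
	corrCols  []string // columns resolved against the outer query
	localCols []string // columns resolved against the sub-query's own FROM clause
}

// belongsToOuter mirrors the check used by extractCorrelatedAggFuncs.
func belongsToOuter(a aggCall) bool {
	return len(a.corrCols) > 0 && len(a.localCols) == 0
}

func main() {
	// select (select count(n.a) from t) from t n  -> count(n.a) is evaluated in the outer query
	countOuter := aggCall{text: "count(n.a)", corrCols: []string{"n.a"}}
	// select (select count(n.a + a) from t) from t n  -> count(n.a + a) stays in the sub-query
	countMixed := aggCall{text: "count(n.a + a)", corrCols: []string{"n.a"}, localCols: []string{"t.a"}}

	fmt.Println(countOuter.text, "-> outer:", belongsToOuter(countOuter)) // true
	fmt.Println(countMixed.text, "-> outer:", belongsToOuter(countMixed)) // false
}
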
@@ -232,23 +233,40 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu newFunc.OrderByItems = append(newFunc.OrderByItems, &util.ByItems{Expr: newByItem, Desc: byItem.Desc}) } } + // combine identical aggregate functions combined := false - for j, oldFunc := range plan4Agg.AggFuncs { + for j := 0; j < i; j++ { + oldFunc := plan4Agg.AggFuncs[aggIndexMap[j]] if oldFunc.Equal(b.ctx, newFunc) { - aggIndexMap[i] = j + aggIndexMap[i] = aggIndexMap[j] combined = true + if _, ok := correlatedAggMap[aggFunc]; ok { + if _, ok = b.correlatedAggMapper[aggFuncList[j]]; !ok { + b.correlatedAggMapper[aggFuncList[j]] = &expression.CorrelatedColumn{ + Column: *schema4Agg.Columns[aggIndexMap[j]], + } + } + b.correlatedAggMapper[aggFunc] = b.correlatedAggMapper[aggFuncList[j]] + } break } } + // create new columns for aggregate functions which show up first if !combined { position := len(plan4Agg.AggFuncs) aggIndexMap[i] = position plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) - schema4Agg.Append(&expression.Column{ + column := expression.Column{ UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), RetType: newFunc.RetTp, - }) + } + schema4Agg.Append(&column) names = append(names, types.EmptyName) + if _, ok := correlatedAggMap[aggFunc]; ok { + b.correlatedAggMapper[aggFunc] = &expression.CorrelatedColumn{ + Column: column, + } + } } } for i, col := range p.Schema().Columns { @@ -278,6 +296,34 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu return plan4Agg, aggIndexMap, nil } +func (b *PlanBuilder) buildTableRefsWithCache(ctx context.Context, from *ast.TableRefsClause) (p LogicalPlan, err error) { + return b.buildTableRefs(ctx, from, true) +} + +func (b *PlanBuilder) buildTableRefs(ctx context.Context, from *ast.TableRefsClause, useCache bool) (p LogicalPlan, err error) { + if from == nil { + p = b.buildTableDual() + return + } + if !useCache { + return b.buildResultSetNode(ctx, from.TableRefs) + } + var ok bool + p, ok = b.cachedResultSetNodes[from.TableRefs] + if ok { + m := b.cachedHandleHelperMap[from.TableRefs] + b.handleHelper.pushMap(m) + return + } + p, err = b.buildResultSetNode(ctx, from.TableRefs) + if err != nil { + return nil, err + } + b.cachedResultSetNodes[from.TableRefs] = p + b.cachedHandleHelperMap[from.TableRefs] = b.handleHelper.tailMap() + return +} + func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSetNode) (p LogicalPlan, err error) { switch x := node.(type) { case *ast.Join: @@ -1814,8 +1860,8 @@ func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan return havingAggMapper, extractor.aggMapper, nil } -func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { - extractor := &AggregateFuncExtractor{} +func (b *PlanBuilder) extractAggFuncsInSelectFields(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} for _, f := range fields { n, _ := f.Expr.Accept(extractor) f.Expr = n.(ast.ExprNode) @@ -1829,6 +1875,38 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega return aggList, totalAggMapper } +func (b *PlanBuilder) extractAggFuncsInByItems(byItems []*ast.ByItem) []*ast.AggregateFuncExpr { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} + for _, f := range byItems { + n, _ := f.Expr.Accept(extractor) + f.Expr = n.(ast.ExprNode) 
+ } + return extractor.AggFuncs +} + +// extractCorrelatedAggFuncs extracts correlated aggregates which belong to outer query from aggregate function list. +func (b *PlanBuilder) extractCorrelatedAggFuncs(ctx context.Context, p LogicalPlan, aggFuncs []*ast.AggregateFuncExpr) (outer []*ast.AggregateFuncExpr, err error) { + corCols := make([]*expression.CorrelatedColumn, 0, len(aggFuncs)) + cols := make([]*expression.Column, 0, len(aggFuncs)) + aggMapper := make(map[*ast.AggregateFuncExpr]int) + for _, agg := range aggFuncs { + for _, arg := range agg.Args { + expr, _, err := b.rewrite(ctx, arg, p, aggMapper, true) + if err != nil { + return nil, err + } + corCols = append(corCols, expression.ExtractCorColumns(expr)...) + cols = append(cols, expression.ExtractColumns(expr)...) + } + if len(corCols) > 0 && len(cols) == 0 { + outer = append(outer, agg) + } + aggMapper[agg] = -1 + corCols, cols = corCols[:0], cols[:0] + } + return +} + // resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields. func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) ( map[*ast.AggregateFuncExpr]int, error) { @@ -1874,15 +1952,232 @@ func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) return extractor.aggMapper, nil } +// correlatedAggregateResolver visits Expr tree. +// It finds and collects all correlated aggregates which should be evaluated in the outer query. +type correlatedAggregateResolver struct { + ctx context.Context + err error + b *PlanBuilder + outerPlan LogicalPlan + + // correlatedAggFuncs stores aggregate functions which belong to outer query + correlatedAggFuncs []*ast.AggregateFuncExpr +} + +// Enter implements Visitor interface. +func (r *correlatedAggregateResolver) Enter(n ast.Node) (ast.Node, bool) { + switch v := n.(type) { + case *ast.SelectStmt: + if r.outerPlan != nil { + outerSchema := r.outerPlan.Schema() + r.b.outerSchemas = append(r.b.outerSchemas, outerSchema) + r.b.outerNames = append(r.b.outerNames, r.outerPlan.OutputNames()) + } + r.err = r.resolveSelect(v) + return n, true + } + return n, false +} + +// resolveSelect finds and collects correlated aggregates within the SELECT stmt. +// It resolves and builds FROM clause first to get a source plan, from which we can decide +// whether a column is correlated or not. +// Then it collects correlated aggregate from SELECT fields (including sub-queries), HAVING, +// ORDER BY, WHERE & GROUP BY. +// Finally it restore the original SELECT stmt. +func (r *correlatedAggregateResolver) resolveSelect(sel *ast.SelectStmt) (err error) { + // collect correlated aggregate from sub-queries inside FROM clause. + useCache, err := r.collectFromTableRefs(r.ctx, sel.From) + if err != nil { + return err + } + // we cannot use cache if there are correlated aggregates inside FROM clause, + // since the plan we are building now is not correct and need to be rebuild later. 
+ p, err := r.b.buildTableRefs(r.ctx, sel.From, useCache) + if err != nil { + return err + } + + // similar to process in PlanBuilder.buildSelect + originalFields := sel.Fields.Fields + sel.Fields.Fields, err = r.b.unfoldWildStar(p, sel.Fields.Fields) + if err != nil { + return err + } + if r.b.capFlag&canExpandAST != 0 { + originalFields = sel.Fields.Fields + } + + hasWindowFuncField := r.b.detectSelectWindow(sel) + if hasWindowFuncField { + _, err = r.b.resolveWindowFunction(sel, p) + if err != nil { + return err + } + } + + _, _, err = r.b.resolveHavingAndOrderBy(sel, p) + if err != nil { + return err + } + + // find and collect correlated aggregates recursively in sub-queries + _, err = r.b.resolveCorrelatedAggregates(r.ctx, sel, p) + if err != nil { + return err + } + + // collect from SELECT fields, HAVING, ORDER BY and window functions + if r.b.detectSelectAgg(sel) { + err = r.collectFromSelectFields(p, sel.Fields.Fields) + if err != nil { + return err + } + } + + // collect from WHERE + err = r.collectFromWhere(p, sel.Where) + if err != nil { + return err + } + + // collect from GROUP BY + err = r.collectFromGroupBy(p, sel.GroupBy) + if err != nil { + return err + } + + // restore the sub-query + sel.Fields.Fields = originalFields + r.b.handleHelper.popMap() + return nil +} + +func (r *correlatedAggregateResolver) collectFromTableRefs(ctx context.Context, from *ast.TableRefsClause) (canCache bool, err error) { + if from == nil { + return true, nil + } + subResolver := &correlatedAggregateResolver{ + ctx: r.ctx, + b: r.b, + } + _, ok := from.TableRefs.Accept(subResolver) + if !ok { + return false, subResolver.err + } + if len(subResolver.correlatedAggFuncs) == 0 { + return true, nil + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, subResolver.correlatedAggFuncs...) + return false, nil +} + +func (r *correlatedAggregateResolver) collectFromSelectFields(p LogicalPlan, fields []*ast.SelectField) error { + aggList, _ := r.b.extractAggFuncsInSelectFields(fields) + r.b.curClause = fieldList + outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList) + if err != nil { + return nil + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) + return nil +} + +func (r *correlatedAggregateResolver) collectFromGroupBy(p LogicalPlan, groupBy *ast.GroupByClause) error { + if groupBy == nil { + return nil + } + aggList := r.b.extractAggFuncsInByItems(groupBy.Items) + r.b.curClause = groupByClause + outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList) + if err != nil { + return nil + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) + return nil +} + +func (r *correlatedAggregateResolver) collectFromWhere(p LogicalPlan, where ast.ExprNode) error { + if where == nil { + return nil + } + extractor := &AggregateFuncExtractor{skipAggMap: r.b.correlatedAggMapper} + _, _ = where.Accept(extractor) + r.b.curClause = whereClause + outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, extractor.AggFuncs) + if err != nil { + return err + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) + return nil +} + +// Leave implements Visitor interface. 
+func (r *correlatedAggregateResolver) Leave(n ast.Node) (ast.Node, bool) { + switch n.(type) { + case *ast.SelectStmt: + if r.outerPlan != nil { + r.b.outerSchemas = r.b.outerSchemas[0 : len(r.b.outerSchemas)-1] + r.b.outerNames = r.b.outerNames[0 : len(r.b.outerNames)-1] + } + } + return n, true +} + +// resolveCorrelatedAggregates finds and collects all correlated aggregates which should be evaluated +// in the outer query from all the sub-queries inside SELECT fields. +func (b *PlanBuilder) resolveCorrelatedAggregates(ctx context.Context, sel *ast.SelectStmt, p LogicalPlan) (map[*ast.AggregateFuncExpr]int, error) { + resolver := &correlatedAggregateResolver{ + ctx: ctx, + b: b, + outerPlan: p, + } + correlatedAggList := make([]*ast.AggregateFuncExpr, 0) + for _, field := range sel.Fields.Fields { + _, ok := field.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + if sel.Having != nil { + _, ok := sel.Having.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + if sel.OrderBy != nil { + for _, item := range sel.OrderBy.Items { + _, ok := item.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + } + correlatedAggMap := make(map[*ast.AggregateFuncExpr]int) + for _, aggFunc := range correlatedAggList { + correlatedAggMap[aggFunc] = len(sel.Fields.Fields) + sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ + Auxiliary: true, + Expr: aggFunc, + AsName: model.NewCIStr(fmt.Sprintf("sel_subq_agg_%d", len(sel.Fields.Fields))), + }) + } + return correlatedAggMap, nil +} + // gbyResolver resolves group by items from select fields. type gbyResolver struct { - ctx sessionctx.Context - fields []*ast.SelectField - schema *expression.Schema - names []*types.FieldName - err error - inExpr bool - isParam bool + ctx sessionctx.Context + fields []*ast.SelectField + schema *expression.Schema + names []*types.FieldName + err error + inExpr bool + isParam bool + skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn exprDepth int // exprDepth is the depth of current expression in expression tree. 
} @@ -1910,7 +2205,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { } func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { - extractor := &AggregateFuncExtractor{} + extractor := &AggregateFuncExtractor{skipAggMap: g.skipAggMap} switch v := inNode.(type) { case *ast.ColumnNameExpr: idx, err := expression.FindFieldName(g.names, v.Name) @@ -2355,10 +2650,11 @@ func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p LogicalPlan, gby *a b.curClause = groupByClause exprs := make([]expression.Expression, 0, len(gby.Items)) resolver := &gbyResolver{ - ctx: b.ctx, - fields: fields, - schema: p.Schema(), - names: p.OutputNames(), + ctx: b.ctx, + fields: fields, + schema: p.Schema(), + names: p.OutputNames(), + skipAggMap: b.correlatedAggMapper, } for _, item := range gby.Items { resolver.inExpr = false @@ -2654,16 +2950,16 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L aggFuncs []*ast.AggregateFuncExpr havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int windowAggMap map[*ast.AggregateFuncExpr]int + correlatedAggMap map[*ast.AggregateFuncExpr]int gbyCols []expression.Expression ) - if sel.From != nil { - p, err = b.buildResultSetNode(ctx, sel.From.TableRefs) - if err != nil { - return nil, err - } - } else { - p = b.buildTableDual() + // For sub-queries, the FROM clause may have already been built in outer query when resolving correlated aggregates. + // If the ResultSetNode inside FROM clause has nothing to do with correlated aggregates, we can simply get the + // existing ResultSetNode from the cache. + p, err = b.buildTableRefsWithCache(ctx, sel.From) + if err != nil { + return nil, err } originalFields := sel.Fields.Fields @@ -2704,6 +3000,15 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L return nil, err } + // We have to resolve correlated aggregate inside sub-queries before building aggregation and building projection, + // for instance, count(a) inside the sub-query of "select (select count(a)) from t" should be evaluated within + // the context of the outer query. So we have to extract such aggregates from sub-queries and put them into + // SELECT field list. + correlatedAggMap, err = b.resolveCorrelatedAggregates(ctx, sel, p) + if err != nil { + return nil, err + } + // b.allNames will be used in evalDefaultExpr(). Default function is special because it needs to find the // corresponding column name, but does not need the value in the column. // For example, select a from t order by default(b), the column b will not be in select fields. Also because @@ -2732,14 +3037,22 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L hasAgg := b.detectSelectAgg(sel) if hasAgg { - aggFuncs, totalMap = b.extractAggFuncs(sel.Fields.Fields) + aggFuncs, totalMap = b.extractAggFuncsInSelectFields(sel.Fields.Fields) + // len(aggFuncs) == 0 and sel.GroupBy == nil indicates that all the aggregate functions inside the SELECT fields + // are actually correlated aggregates from the outer query, which have already been built in the outer query. + // The only thing we need to do is to find them from b.correlatedAggMap in buildProjection. 
+ if len(aggFuncs) == 0 && sel.GroupBy == nil { + hasAgg = false + } + } + if hasAgg { var aggIndexMap map[int]int - p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols) + p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols, correlatedAggMap) if err != nil { return nil, err } - for k, v := range totalMap { - totalMap[k] = aggIndexMap[v] + for agg, idx := range totalMap { + totalMap[agg] = aggIndexMap[idx] } } diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 6d5891c658c16..3988320b8fd67 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -538,7 +538,7 @@ func (s *testPlanSuite) TestColumnPruning(c *C) { ctx := context.Background() for i, tt := range input { - comment := Commentf("for %s", tt) + comment := Commentf("case:%v sql:\"%s\"", i, tt) stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) @@ -618,6 +618,7 @@ func (s *testPlanSuite) checkDataSourceCols(p LogicalPlan, c *C, ans map[int][]s ans[p.ID()] = make([]string, p.Schema().Len()) }) colList, ok := ans[p.ID()] +<<<<<<< HEAD c.Assert(ok, IsTrue, Commentf("For %v DataSource ID %d Not found", comment, p.ID())) c.Assert(len(p.Schema().Columns), Equals, len(colList), comment) for i, col := range p.Schema().Columns { @@ -632,6 +633,9 @@ func (s *testPlanSuite) checkDataSourceCols(p LogicalPlan, c *C, ans map[int][]s }) colList, ok := ans[p.ID()] c.Assert(ok, IsTrue, Commentf("For %v UnionAll ID %d Not found", comment, p.ID())) +======= + c.Assert(ok, IsTrue, Commentf("For %s %T ID %d Not found", comment.CheckCommentString(), v, p.ID())) +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) c.Assert(len(p.Schema().Columns), Equals, len(colList), comment) for i, col := range p.Schema().Columns { s.testData.OnRecord(func() { @@ -1665,3 +1669,50 @@ func (s *testPlanSuite) TestSimplyOuterJoinWithOnlyOuterExpr(c *C) { // previous wrong JoinType is InnerJoin c.Assert(join.JoinType, Equals, RightOuterJoin) } + +func (s *testPlanSuite) TestResolvingCorrelatedAggregate(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + best string + }{ + { + sql: "select (select count(a)) from t", + best: "Apply{DataScan(t)->Aggr(count(test.t.a))->Dual->Projection->MaxOneRow}->Projection", + }, + { + sql: "select (select count(n.a) from t) from t n", + best: "Apply{DataScan(n)->Aggr(count(test.t.a))->DataScan(t)->Projection->MaxOneRow}->Projection", + }, + { + sql: "select (select sum(count(a))) from t", + best: "Apply{DataScan(t)->Aggr(count(test.t.a))->Dual->Aggr(sum(Column#13))->MaxOneRow}->Projection", + }, + { + sql: "select (select sum(count(n.a)) from t) from t n", + best: "Apply{DataScan(n)->Aggr(count(test.t.a))->DataScan(t)->Aggr(sum(Column#25))->MaxOneRow}->Projection", + }, + { + sql: "select (select cnt from (select count(a) as cnt) n) from t", + best: "Apply{DataScan(t)->Aggr(count(test.t.a))->Dual->Projection->MaxOneRow}->Projection", + }, + { + sql: "select sum(a), sum(a), count(a), (select count(a)) from t", + best: "Apply{DataScan(t)->Aggr(sum(test.t.a),count(test.t.a))->Dual->Projection->MaxOneRow}->Projection", + }, + } + + ctx := context.TODO() + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + err = Preprocess(s.ctx, stmt, s.is) + c.Assert(err, IsNil, comment) + p, _, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) + c.Assert(err, 
IsNil, comment) + p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagEliminateProjection|flagPrunColumns|flagPrunColumnsAgain, p.(LogicalPlan)) + c.Assert(err, IsNil, comment) + c.Assert(ToString(p), Equals, tt.best, comment) + } +} diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index c412d2b36a09e..52caabd3502be 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -452,6 +452,19 @@ type PlanBuilder struct { // evalDefaultExpr needs this information to find the corresponding column. // It stores the OutputNames before buildProjection. allNames [][]*types.FieldName +<<<<<<< HEAD +======= + + // isSampling indicates whether the query is sampling. + isSampling bool + + // correlatedAggMapper stores columns for correlated aggregates which should be evaluated in outer query. + correlatedAggMapper map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn + + // cache ResultSetNodes and HandleHelperMap to avoid rebuilding. + cachedResultSetNodes map[*ast.Join]LogicalPlan + cachedHandleHelperMap map[*ast.Join]map[int64][]HandleCols +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) } type handleColHelper struct { @@ -550,12 +563,24 @@ func NewPlanBuilder(sctx sessionctx.Context, is infoschema.InfoSchema, processor sctx.GetSessionVars().PlannerSelectBlockAsName = make([]ast.HintTable, processor.MaxSelectStmtOffset()+1) } return &PlanBuilder{ +<<<<<<< HEAD ctx: sctx, is: is, colMapper: make(map[*ast.ColumnNameExpr]int), handleHelper: &handleColHelper{id2HandleMapStack: make([]map[int64][]*expression.Column, 0)}, hintProcessor: processor, } +======= + ctx: sctx, + is: is, + colMapper: make(map[*ast.ColumnNameExpr]int), + handleHelper: &handleColHelper{id2HandleMapStack: make([]map[int64][]HandleCols, 0)}, + hintProcessor: processor, + correlatedAggMapper: make(map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn), + cachedResultSetNodes: make(map[*ast.Join]LogicalPlan), + cachedHandleHelperMap: make(map[*ast.Join]map[int64][]HandleCols), + }, savedBlockNames +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) } // Build builds the ast node to a Plan. 
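
To make the new cachedResultSetNodes / cachedHandleHelperMap fields above easier to follow, here is a rough, self-contained sketch of the memoization they enable via buildTableRefsWithCache: the FROM clause of a sub-query is already built once while resolving correlated aggregates, and buildSelect can reuse that plan unless the FROM clause itself contained correlated aggregates. joinNode and logicalPlan are placeholder types standing in for *ast.Join and LogicalPlan; this is not the patch's actual code.

package main

import "fmt"

type joinNode struct{ name string }
type logicalPlan struct{ desc string }

type builder struct {
	cache map[*joinNode]*logicalPlan // keyed by the AST node, like cachedResultSetNodes
}

func (b *builder) buildTableRefs(from *joinNode, useCache bool) *logicalPlan {
	if !useCache {
		// Correlated aggregates were found inside FROM: the cached plan is not
		// trustworthy, so rebuild from scratch.
		return &logicalPlan{desc: "rebuilt(" + from.name + ")"}
	}
	if p, ok := b.cache[from]; ok {
		return p // reuse the plan built during correlated-aggregate resolution
	}
	p := &logicalPlan{desc: "built(" + from.name + ")"}
	b.cache[from] = p
	return p
}

func main() {
	b := &builder{cache: map[*joinNode]*logicalPlan{}}
	from := &joinNode{name: "t"}
	first := b.buildTableRefs(from, true)
	second := b.buildTableRefs(from, true)
	fmt.Println(first == second) // true: the same LogicalPlan instance is reused
}
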
diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index 21faa82b7adc8..d3a36a913ad0c 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -2084,12 +2084,12 @@ { "SQL": "select 1 from (select /*+ HASH_JOIN(t1) */ t1.a in (select t2.a from t2) from t1) x;", "Plan": "LeftHashJoin{IndexReader(Index(t1.idx_a)[[NULL,+inf]])->IndexReader(Index(t2.idx_a)[[NULL,+inf]])}->Projection", - "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_3` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" + "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_2` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" }, { "SQL": "select 1 from (select /*+ HASH_JOIN(t1) */ t1.a not in (select t2.a from t2) from t1) x;", "Plan": "LeftHashJoin{IndexReader(Index(t1.idx_a)[[NULL,+inf]])->IndexReader(Index(t2.idx_a)[[NULL,+inf]])}->Projection", - "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_3` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" + "Hints": "use_index(@`sel_2` `test`.`t1` `idx_a`), use_index(@`sel_2` `test`.`t2` `idx_a`), hash_join(@`sel_2` `test`.`t1`)" }, { "SQL": "select /*+ INL_JOIN(t1) */ t1.b, t2.b from t1 inner join t2 on t1.a = t2.a;", diff --git a/planner/core/testdata/plan_suite_unexported_out.json b/planner/core/testdata/plan_suite_unexported_out.json index 90047adc595a5..99c2c6a0237c6 100644 --- a/planner/core/testdata/plan_suite_unexported_out.json +++ b/planner/core/testdata/plan_suite_unexported_out.json @@ -101,7 +101,7 @@ "Cases": [ "Join{DataScan(t)->DataScan(s)}(test.t.a,test.t.a)->Projection", "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Projection->Projection", - "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Aggr(firstrow(Column#13),count(test.t.b))->Projection->Projection", + "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Aggr(firstrow(Column#25),count(test.t.b))->Projection->Projection", "Apply{DataScan(t)->DataScan(s)->Sel([eq(test.t.a, test.t.a)])->Aggr(count(test.t.b))}->Projection", "Join{DataScan(t)->DataScan(s)->Aggr(count(test.t.b),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", "Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(s)->Aggr(count(test.t.b),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", @@ -632,7 +632,7 @@ "1": [ "test.t.a" ], - "3": [ + "2": [ "test.t.a", "test.t.b" ] @@ -647,7 +647,7 @@ "test.t.a", "test.t.b" ], - "3": [ + "2": [ "test.t.b" ] }, @@ -655,7 +655,7 @@ "1": [ "test.t.a" ], - "3": [ + "2": [ "test.t.b" ] }, diff --git a/planner/core/util.go b/planner/core/util.go index 84b77e779d2b8..9ccccf658a183 100644 --- a/planner/core/util.go +++ b/planner/core/util.go @@ -27,9 +27,11 @@ import ( ) // AggregateFuncExtractor visits Expr tree. -// It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. +// It collects AggregateFuncExpr from AST Node. type AggregateFuncExtractor struct { - inAggregateFuncExpr bool + // skipAggMap stores correlated aggregate functions which have been built in outer query, + // so extractor in sub-query will skip these aggregate functions. + skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn // AggFuncs is the collected AggregateFuncExprs. 
AggFuncs []*ast.AggregateFuncExpr } @@ -37,9 +39,13 @@ type AggregateFuncExtractor struct { // Enter implements Visitor interface. func (a *AggregateFuncExtractor) Enter(n ast.Node) (ast.Node, bool) { switch n.(type) { +<<<<<<< HEAD case *ast.AggregateFuncExpr: a.inAggregateFuncExpr = true case *ast.SelectStmt, *ast.UnionStmt: +======= + case *ast.SelectStmt, *ast.SetOprStmt: +>>>>>>> f687ebd91... planner: fix correlated aggregates which should be evaluated in outer query (#21431) return n, true } return n, false @@ -49,8 +55,9 @@ func (a *AggregateFuncExtractor) Enter(n ast.Node) (ast.Node, bool) { func (a *AggregateFuncExtractor) Leave(n ast.Node) (ast.Node, bool) { switch v := n.(type) { case *ast.AggregateFuncExpr: - a.inAggregateFuncExpr = false - a.AggFuncs = append(a.AggFuncs, v) + if _, ok := a.skipAggMap[v]; !ok { + a.AggFuncs = append(a.AggFuncs, v) + } } return n, true }
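
As a closing illustration of the util.go change above, a minimal standalone sketch of the skipAggMap behaviour: the extractor collects aggregate functions from an AST fragment but skips those the outer query has already replaced with correlated columns, so a sub-query never builds them twice. aggExpr and corrCol are simplified stand-ins for *ast.AggregateFuncExpr and *expression.CorrelatedColumn, and visit is a stand-in for the real Enter/Leave visitor pair.

package main

import "fmt"

type aggExpr struct{ text string }
type corrCol struct{ id int }

type extractor struct {
	skipAggMap map[*aggExpr]*corrCol // aggregates already built in the outer query
	aggFuncs   []*aggExpr            // aggregates collected for the current query block
}

func (e *extractor) visit(aggs ...*aggExpr) {
	for _, a := range aggs {
		if _, ok := e.skipAggMap[a]; ok {
			continue // evaluated in the outer query; skip it here
		}
		e.aggFuncs = append(e.aggFuncs, a)
	}
}

func main() {
	outerCount := &aggExpr{text: "count(n.a)"} // hoisted into the outer query
	localSum := &aggExpr{text: "sum(t.b)"}     // belongs to this query block

	e := &extractor{skipAggMap: map[*aggExpr]*corrCol{outerCount: {id: 1}}}
	e.visit(outerCount, localSum)
	for _, a := range e.aggFuncs {
		fmt.Println("collected:", a.text) // only sum(t.b) is collected
	}
}
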