diff --git a/expression/integration_test.go b/expression/integration_test.go index 145214e576a39..4a35f580e30c9 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -9216,6 +9216,29 @@ func (s *testIntegrationSuite) TestRefineArgNullValues(c *C) { )) } +func (s *testIntegrationSuite) TestIssue26958(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1 (c_int int not null);") + tk.MustExec("insert into t1 values (1), (2), (3),(1),(2),(3);") + tk.MustExec("drop table if exists t2;") + tk.MustExec("create table t2 (c_int int not null);") + tk.MustExec("insert into t2 values (1), (2), (3),(1),(2),(3);") + tk.MustQuery("select \n(select count(distinct c_int) from t2 where c_int >= t1.c_int) c1, \n(select count(distinct c_int) from t2 where c_int >= t1.c_int) c2\nfrom t1 group by c_int;\n"). + Check(testkit.Rows("3 3", "2 2", "1 1")) +} + +func (s *testIntegrationSuite) TestIssue27233(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("CREATE TABLE `t` (\n `COL1` tinyint(45) NOT NULL,\n `COL2` tinyint(45) NOT NULL,\n PRIMARY KEY (`COL1`,`COL2`) /*T![clustered_index] NONCLUSTERED */\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;") + tk.MustExec("insert into t values(122,100),(124,-22),(124,34),(127,103);") + tk.MustQuery("SELECT col2 FROM t AS T1 WHERE ( SELECT count(DISTINCT COL1, COL2) FROM t AS T2 WHERE T2.COL1 > T1.COL1 ) > 2 ;"). + Check(testkit.Rows("100")) +} + func (s *testIntegrationSerialSuite) TestIssue26662(c *C) { collate.SetNewCollationEnabledForTest(true) defer collate.SetNewCollationEnabledForTest(false) diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index b326bbfa29265..83848b8087e73 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -1694,6 +1694,39 @@ func (s *testPlanSuite) TestNthPlanHintWithExplain(c *C) { tk.MustQuery("explain format = 'brief' select * from test.tt where a=1 and b=1").Check(testkit.Rows(output[1].Plan...)) } +func (s *testPlanSuite) TestIssue27233(c *C) { + var ( + input []string + output []struct { + SQL string + Plan []string + Result []string + } + ) + s.testData.GetTestCases(c, &input, &output) + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + tk := testkit.NewTestKit(c, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE `PK_S_MULTI_31` (\n `COL1` tinyint(45) NOT NULL,\n `COL2` tinyint(45) NOT NULL,\n PRIMARY KEY (`COL1`,`COL2`) /*T![clustered_index] NONCLUSTERED */\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;") + tk.MustExec("insert into PK_S_MULTI_31 values(122,100),(124,-22),(124,34),(127,103);") + + for i, ts := range input { + s.testData.OnRecord(func() { + output[i].SQL = ts + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery("explain format='brief'" + ts).Rows()) + output[i].Result = s.testData.ConvertRowsToStrings(tk.MustQuery(ts).Sort().Rows()) + }) + tk.MustQuery("explain format='brief' " + ts).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(ts).Sort().Check(testkit.Rows(output[i].Result...)) + } +} + func (s *testPlanSuite) TestPossibleProperties(c *C) { store, dom, err := newStoreWithBootstrap() c.Assert(err, IsNil) diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index bf7e332c5eb33..1e97d6ff21454 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -189,22 +189,28 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (Logica resetNotNullFlag(apply.schema, outerPlan.Schema().Len(), apply.schema.Len()) for i, aggFunc := range agg.AggFuncs { - switch expr := aggFunc.Args[0].(type) { - case *expression.Column: - if idx := apply.schema.ColumnIndex(expr); idx != -1 { - desc, err := aggregation.NewAggFuncDesc(agg.ctx, agg.AggFuncs[i].Name, []expression.Expression{apply.schema.Columns[idx]}, false) - if err != nil { - return nil, err + aggArgs := make([]expression.Expression, 0, len(aggFunc.Args)) + for _, arg := range aggFunc.Args { + switch expr := arg.(type) { + case *expression.Column: + if idx := apply.schema.ColumnIndex(expr); idx != -1 { + aggArgs = append(aggArgs, apply.schema.Columns[idx]) + } else { + aggArgs = append(aggArgs, expr) } - newAggFuncs = append(newAggFuncs, desc) + case *expression.ScalarFunction: + expr.RetType = expr.RetType.Clone() + expr.RetType.Flag &= ^mysql.NotNullFlag + aggArgs = append(aggArgs, expr) + default: + aggArgs = append(aggArgs, expr) } - case *expression.ScalarFunction: - expr.RetType = expr.RetType.Clone() - expr.RetType.Flag &= ^mysql.NotNullFlag - newAggFuncs = append(newAggFuncs, aggFunc) - default: - newAggFuncs = append(newAggFuncs, aggFunc) } + desc, err := aggregation.NewAggFuncDesc(agg.ctx, agg.AggFuncs[i].Name, aggArgs, agg.AggFuncs[i].HasDistinct) + if err != nil { + return nil, err + } + newAggFuncs = append(newAggFuncs, desc) } agg.AggFuncs = newAggFuncs np, err := s.optimize(ctx, p) diff --git a/planner/core/testdata/plan_suite_in.json b/planner/core/testdata/plan_suite_in.json index ea257f22a332c..69ef82b9862c0 100644 --- a/planner/core/testdata/plan_suite_in.json +++ b/planner/core/testdata/plan_suite_in.json @@ -676,5 +676,11 @@ "select a from t2 where t2.a < (select t1.a from t1 where t1.b = t2.b and t1.a is null);", "select a from t2 where t2.a < (select t3.a from t3 where t3.a = t2.a);" ] + }, + { + "name": "TestIssue27233", + "cases": [ + "SELECT col2 FROM PK_S_MULTI_31 AS T1 WHERE (SELECT count(DISTINCT COL1, COL2) FROM PK_S_MULTI_31 AS T2 WHERE T2.COL1>T1.COL1)>2 order by col2;" + ] } ] diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index 977ae8eeb64b2..a45e766c2b098 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -2277,5 +2277,27 @@ "Result": null } ] + }, + { + "Name": "TestIssue27233", + "Cases": [ + { + "SQL": "SELECT col2 FROM PK_S_MULTI_31 AS T1 WHERE (SELECT count(DISTINCT COL1, COL2) FROM PK_S_MULTI_31 AS T2 WHERE T2.COL1>T1.COL1)>2 order by col2;", + "Plan": [ + "Sort 0.80 root test.pk_s_multi_31.col2", + "└─Projection 0.80 root test.pk_s_multi_31.col2", + " └─Selection 0.80 root gt(Column#7, 2)", + " └─HashAgg 1.00 root group by:test.pk_s_multi_31.col1, test.pk_s_multi_31.col2, funcs:firstrow(test.pk_s_multi_31.col2)->test.pk_s_multi_31.col2, funcs:count(distinct test.pk_s_multi_31.col1, test.pk_s_multi_31.col2)->Column#7", + " └─HashJoin 100000000.00 root CARTESIAN left outer join, other cond:gt(test.pk_s_multi_31.col1, test.pk_s_multi_31.col1)", + " ├─IndexReader(Build) 10000.00 root index:IndexFullScan", + " │ └─IndexFullScan 10000.00 cop[tikv] table:T2, index:PRIMARY(COL1, COL2) keep order:false, stats:pseudo", + " └─IndexReader(Probe) 10000.00 root index:IndexFullScan", + " └─IndexFullScan 10000.00 cop[tikv] table:T1, index:PRIMARY(COL1, COL2) keep order:false, stats:pseudo" + ], + "Result": [ + "100" + ] + } + ] } ]