diff --git a/executor/executor_test.go b/executor/executor_test.go index 3412567e0a53b..1ae0306ad0911 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1021,6 +1021,21 @@ func (s *testSuite) TestUnion(c *C) { tk.MustQuery("select 1 union select 1 union all select 1").Check(testkit.Rows("1", "1")) tk.MustQuery("select 1 union all select 1 union select 1").Check(testkit.Rows("1")) + + tk.MustExec("drop table if exists t1, t2") + tk.MustExec(`create table t1(a bigint, b bigint);`) + tk.MustExec(`create table t2(a bigint, b bigint);`) + tk.MustExec(`insert into t1 values(1, 1);`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t1 select * from t1;`) + tk.MustExec(`insert into t2 values(1, 1);`) + tk.MustExec(`set @@tidb_max_chunk_size=2;`) + tk.MustQuery(`select count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("128")) + tk.MustQuery(`select tmp.a, count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("1 128")) } func (s *testSuite) TestIn(c *C) { diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 933c93b206900..afa552019588e 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -129,6 +129,15 @@ func (ls *LogicalSort) PruneColumns(parentUsedCols []*expression.Column) { // PruneColumns implements LogicalPlan interface. func (p *LogicalUnionAll) PruneColumns(parentUsedCols []*expression.Column) { + used := getUsedList(parentUsedCols, p.schema) + hasBeenUsed := false + for i := range used { + hasBeenUsed = hasBeenUsed || used[i] + } + if !hasBeenUsed { + parentUsedCols = make([]*expression.Column, len(p.schema.Columns)) + copy(parentUsedCols, p.schema.Columns) + } for _, child := range p.Children() { child.PruneColumns(parentUsedCols) }