Skip to content

Commit c8d78a7

Browse files
peter-tothcloud-fan
authored andcommitted
[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions without aggregate function
### What changes were proposed in this pull request? This PR: - Adds a new expression `GroupingExprRef` that can be used in aggregate expressions of `Aggregate` nodes to refer grouping expressions by index. These expressions capture the data type and nullability of the referred grouping expression. - Adds a new rule `EnforceGroupingReferencesInAggregates` that inserts the references in the beginning of the optimization phase. - Adds a new rule `UpdateGroupingExprRefNullability` to update nullability of `GroupingExprRef` expressions as nullability of referred grouping expression can change during optimization. ### Why are the changes needed? If aggregate expressions (without aggregate functions) in an `Aggregate` node are complex then the `Optimizer` can optimize out grouping expressions from them and so making aggregate expressions invalid. Here is a simple example: ``` SELECT not(t.id IS NULL) , count(*) FROM t GROUP BY t.id IS NULL ``` In this case the `BooleanSimplification` rule does this: ``` === Applying Rule org.apache.spark.sql.catalyst.optimizer.BooleanSimplification === !Aggregate [isnull(id#222)], [NOT isnull(id#222) AS (NOT (id IS NULL))#226, count(1) AS c#224L] Aggregate [isnull(id#222)], [isnotnull(id#222) AS (NOT (id IS NULL))#226, count(1) AS c#224L] +- Project [value#219 AS id#222] +- Project [value#219 AS id#222] +- LocalRelation [value#219] +- LocalRelation [value#219] ``` where `NOT isnull(id#222)` is optimized to `isnotnull(id#222)` and so it no longer refers to any grouping expression. Before this PR: ``` == Optimized Logical Plan == Aggregate [isnull(id#222)], [isnotnull(id#222) AS (NOT (id IS NULL))#234, count(1) AS c#232L] +- Project [value#219 AS id#222] +- LocalRelation [value#219] ``` and running the query throws an error: ``` Couldn't find id#222 in [isnull(id#222)#230,count(1)#226L] java.lang.IllegalStateException: Couldn't find id#222 in [isnull(id#222)#230,count(1)#226L] ``` After this PR: ``` == Optimized Logical Plan == Aggregate [isnull(id#222)], [NOT groupingexprref(0) AS (NOT (id IS NULL))#234, count(1) AS c#232L] +- Project [value#219 AS id#222] +- LocalRelation [value#219] ``` and the query works. ### Does this PR introduce _any_ user-facing change? Yes, the query works. ### How was this patch tested? Added new UT. Closes #31913 from peter-toth/SPARK-34581-keep-grouping-expressions. Authored-by: Peter Toth <peter.toth@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent fd08c93 commit c8d78a7

File tree

15 files changed

+247
-68
lines changed

15 files changed

+247
-68
lines changed
Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20-
import org.apache.spark.sql.catalyst.expressions.Attribute
21-
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
20+
import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression}
21+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan}
2222
import org.apache.spark.sql.catalyst.rules.Rule
2323

2424
/**
@@ -52,3 +52,22 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] {
5252
}
5353
}
5454
}
55+
56+
/**
57+
* Updates nullability of [[GroupingExprRef]]s in a resolved LogicalPlan by using the nullability of
58+
* referenced grouping expression.
59+
*/
60+
object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
61+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
62+
case a: Aggregate =>
63+
val nullabilities = a.groupingExpressions.map(_.nullable).toArray
64+
65+
val newAggregateExpressions =
66+
a.aggregateExpressions.map(_.transform {
67+
case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) =>
68+
g.copy(nullable = nullabilities(g.ordinal))
69+
}.asInstanceOf[NamedExpression])
70+
71+
a.copy(aggregateExpressions = newAggregateExpressions)
72+
}
73+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ trait AliasHelper {
3535
protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
3636
// Find all the aliased expressions in the aggregate list that don't include any actual
3737
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
38-
val aliasMap = plan.aggregateExpressions.collect {
38+
val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
3939
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
4040
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
4141
(a.toAttribute, a)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ object AggregateExpression {
8080
filter,
8181
NamedExpression.newExprId)
8282
}
83+
84+
def containsAggregate(expr: Expression): Boolean = {
85+
expr.find(isAggregate).isDefined
86+
}
87+
88+
def isAggregate(expr: Expression): Boolean = {
89+
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
90+
}
8391
}
8492

8593
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,22 @@ object GroupingAnalytics {
277277
}
278278
}
279279
}
280+
281+
/**
282+
* A reference to an grouping expression in [[Aggregate]] node.
283+
*
284+
* @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression
285+
* refers to.
286+
* @param dataType The [[DataType]] of the referenced grouping expression.
287+
* @param nullable True if null is a valid value for the referenced grouping expression.
288+
*/
289+
case class GroupingExprRef(
290+
ordinal: Int,
291+
dataType: DataType,
292+
nullable: Boolean)
293+
extends LeafExpression with Unevaluable {
294+
295+
override def stringArgs: Iterator[Any] = {
296+
Iterator(ordinal)
297+
}
298+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,14 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21-
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
21+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2222
import org.apache.spark.sql.catalyst.rules.Rule
2323

2424
/**
2525
* Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
2626
*/
2727
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
2828
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
29-
// One place where this optimization is invalid is an aggregation where the select
30-
// list expression is a function of a grouping expression:
31-
//
32-
// SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
33-
//
34-
// cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
35-
// optimization for Aggregates (although this misses some cases where the optimization
36-
// can be made).
37-
case a: Aggregate => a
3829
case p => p.transformExpressionsUp {
3930
// Remove redundant field extraction.
4031
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
23+
/**
24+
* This rule ensures that [[Aggregate]] nodes contain all required [[GroupingExprRef]]
25+
* references for optimization phase.
26+
*/
27+
object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] {
28+
override def apply(plan: LogicalPlan): LogicalPlan = {
29+
plan transform {
30+
case a: Aggregate =>
31+
Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child)
32+
}
33+
}
34+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
118118
OptimizeUpdateFields,
119119
SimplifyExtractValueOps,
120120
OptimizeCsvJsonExprs,
121-
CombineConcats) ++
121+
CombineConcats,
122+
UpdateGroupingExprRefNullability) ++
122123
extendedOperatorOptimizationRules
123124

124125
val operatorOptimizationBatch: Seq[Batch] = {
@@ -147,6 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
147148
EliminateView,
148149
ReplaceExpressions,
149150
RewriteNonCorrelatedExists,
151+
EnforceGroupingReferencesInAggregates,
150152
ComputeCurrentTime,
151153
GetCurrentDatabaseAndCatalog(catalogManager)) ::
152154
//////////////////////////////////////////////////////////////////////////////////////////
@@ -266,7 +268,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
266268
RewriteCorrelatedScalarSubquery.ruleName ::
267269
RewritePredicateSubquery.ruleName ::
268270
NormalizeFloatingNumbers.ruleName ::
269-
ReplaceUpdateFieldsExpression.ruleName :: Nil
271+
ReplaceUpdateFieldsExpression.ruleName ::
272+
EnforceGroupingReferencesInAggregates.ruleName ::
273+
UpdateGroupingExprRefNullability.ruleName :: Nil
270274

271275
/**
272276
* Optimize all the subqueries inside expression.
@@ -506,7 +510,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
506510
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
507511
val aliasMap = getAliasMap(lower)
508512

509-
val newAggregate = upper.copy(
513+
val newAggregate = Aggregate.withGroupingRefs(
510514
child = lower.child,
511515
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
512516
aggregateExpressions = upper.aggregateExpressions.map(
@@ -522,23 +526,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
522526
}
523527

524528
private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
525-
val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
529+
val upperHasNoAggregateExpressions =
530+
!upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)
526531

527532
lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
528533
lower
529534
.aggregateExpressions
530535
.filter(_.deterministic)
531-
.filter(!isAggregate(_))
536+
.filterNot(AggregateExpression.containsAggregate)
532537
.map(_.toAttribute)
533538
))
534539

535540
upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
536541
}
537-
538-
private def isAggregate(expr: Expression): Boolean = {
539-
expr.find(e => e.isInstanceOf[AggregateExpression] ||
540-
PythonUDF.isGroupedAggPandasUDF(e)).isDefined
541-
}
542542
}
543543

544544
/**
@@ -1979,7 +1979,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
19791979
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
19801980
val newGrouping = grouping.filter(!_.foldable)
19811981
if (newGrouping.nonEmpty) {
1982-
a.copy(groupingExpressions = newGrouping)
1982+
val droppedGroupsBefore =
1983+
grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray
1984+
1985+
val newAggregateExpressions =
1986+
a.aggregateExpressions.map(_.transform {
1987+
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
1988+
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
1989+
}.asInstanceOf[NamedExpression])
1990+
1991+
a.copy(
1992+
groupingExpressions = newGrouping,
1993+
aggregateExpressions = newAggregateExpressions)
19831994
} else {
19841995
// All grouping expressions are literals. We should not drop them all, because this can
19851996
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
@@ -2000,7 +2011,25 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
20002011
if (newGrouping.size == grouping.size) {
20012012
a
20022013
} else {
2003-
a.copy(groupingExpressions = newGrouping)
2014+
var i = 0
2015+
val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
2016+
n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
2017+
i += 1
2018+
0
2019+
} else {
2020+
1
2021+
})
2022+
).toArray
2023+
2024+
val newAggregateExpressions =
2025+
a.aggregateExpressions.map(_.transform {
2026+
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
2027+
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
2028+
}.asInstanceOf[NamedExpression])
2029+
2030+
a.copy(
2031+
groupingExpressions = newGrouping,
2032+
aggregateExpressions = newAggregateExpressions)
20042033
}
20052034
}
20062035
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,9 +632,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
632632
* subqueries.
633633
*/
634634
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
635-
case a @ Aggregate(grouping, expressions, child) =>
635+
case a @ Aggregate(grouping, _, child) =>
636636
val subqueries = ArrayBuffer.empty[ScalarSubquery]
637-
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
637+
val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
638+
.map(extractCorrelatedScalarSubqueries(_, subqueries))
638639
if (subqueries.nonEmpty) {
639640
// We currently only allow correlated subqueries in an aggregate if they are part of the
640641
// grouping expressions. As a result we need to replace all the scalar subqueries in the

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ object PhysicalAggregation {
287287
(Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
288288

289289
def unapply(a: Any): Option[ReturnType] = a match {
290-
case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
290+
case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
291291
// A single aggregate expression might appear multiple times in resultExpressions.
292292
// In order to avoid evaluating an individual aggregate function multiple times, we'll
293293
// build a set of semantically distinct aggregate expressions and re-write expressions so
@@ -297,11 +297,9 @@ object PhysicalAggregation {
297297
val aggregateExpressions = resultExpressions.flatMap { expr =>
298298
expr.collect {
299299
// addExpr() always returns false for non-deterministic expressions and do not add them.
300-
case agg: AggregateExpression
301-
if !equivalentAggregateExpressions.addExpr(agg) => agg
302-
case udf: PythonUDF
303-
if PythonUDF.isGroupedAggPandasUDF(udf) &&
304-
!equivalentAggregateExpressions.addExpr(udf) => udf
300+
case a
301+
if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
302+
a
305303
}
306304
}
307305

@@ -322,7 +320,7 @@ object PhysicalAggregation {
322320
// which takes the grouping columns and final aggregate result buffer as input.
323321
// Thus, we must re-write the result expressions so that their attributes match up with
324322
// the attributes of the final result projection's input row:
325-
val rewrittenResultExpressions = resultExpressions.map { expr =>
323+
val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
326324
expr.transformDown {
327325
case ae: AggregateExpression =>
328326
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,

0 commit comments

Comments
 (0)