
[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions without aggregate function #31913


Closed
Changes from all commits (22 commits)

ae1186f  [SPARK-34581][SQL] Don't optimize out grouping expressions from aggre… (peter-toth, Mar 21, 2021)
5ab9f75  comment fix (peter-toth, Mar 21, 2021)
2293fd4  move logic to the beginning of optimization, simplify test (peter-toth, Mar 22, 2021)
3de19ca  regenerate approved plans (peter-toth, Mar 22, 2021)
04e61c5  Merge branch 'master' into SPARK-34581-keep-grouping-expressions (peter-toth, Mar 23, 2021)
6e05f14  define GroupingExpression as TaggingExpression (peter-toth, Mar 23, 2021)
09f1a85  move test to SQLQueryTestSuite (peter-toth, Mar 24, 2021)
f46b89d  add more explanation (peter-toth, Mar 24, 2021)
56589a3  Merge commit 'c8233f1be5c2f853f42cda367475eb135a83afd5' into SPARK-34… (peter-toth, Mar 26, 2021)
ea95bff  Merge commit '3951e3371a83578a81474ed99fb50d59f27aac62' into SPARK-34… (peter-toth, Mar 31, 2021)
7ea2306  Merge commit '89ae83d19b9652348a685550c2c49920511160d5' into SPARK-34… (peter-toth, Apr 1, 2021)
468534f  Merge commit '65da9287bc5112564836a555cd2967fc6b05856f' into SPARK-34… (peter-toth, Apr 2, 2021)
977c0bf  new GroupingExprRef approach (peter-toth, Mar 27, 2021)
c2ba804  simplify (peter-toth, Apr 11, 2021)
0622444  minor fixes (peter-toth, Apr 12, 2021)
343f35e  Merge commit 'e40fce919ab77f5faeb0bbd34dc86c56c04adbaa' into SPARK-34… (peter-toth, Apr 12, 2021)
2e79eb9  review fixes (peter-toth, Apr 13, 2021)
cff9b9a  fix latest test failures, add new test case (peter-toth, Apr 14, 2021)
78296a8  better non-deterministic test case (peter-toth, Apr 14, 2021)
72c173b  make new rules non excludable (peter-toth, Apr 15, 2021)
34f0439  Merge branch 'fork/master' into SPARK-34581-keep-grouping-expressions (peter-toth, Apr 15, 2021)
fb3a19d  fix validConstraints, minor changes (peter-toth, Apr 17, 2021)
@@ -17,8 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

/**
@@ -52,3 +52,22 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] {
}
}
}

/**
* Updates the nullability of [[GroupingExprRef]]s in a resolved LogicalPlan using the nullability
* of the referenced grouping expressions.
*/
object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a: Aggregate =>
val nullabilities = a.groupingExpressions.map(_.nullable).toArray

val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) =>
g.copy(nullable = nullabilities(g.ordinal))
}.asInstanceOf[NamedExpression])

a.copy(aggregateExpressions = newAggregateExpressions)
}
}
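To see the rule in isolation, here is a self-contained Scala sketch of the same ordinal-based nullability sync; Ref and Agg are simplified stand-ins for GroupingExprRef and Aggregate, not Spark's actual classes, and the snippet can be pasted into a Scala REPL:

// Simplified stand-ins for GroupingExprRef and Aggregate (illustration only).
case class Ref(ordinal: Int, nullable: Boolean)
case class Agg(groupingNullable: Seq[Boolean], refs: Seq[Ref])

// Mirror of the rule: copy each grouping expression's nullability onto the
// refs that point at it, touching only refs that are out of date.
def updateNullability(a: Agg): Agg = {
  val nullabilities = a.groupingNullable.toArray
  a.copy(refs = a.refs.map {
    case r if r.nullable != nullabilities(r.ordinal) =>
      r.copy(nullable = nullabilities(r.ordinal))
    case r => r
  })
}

// A ref still marked nullable after its grouping expression became non-nullable:
val fixed = updateNullability(Agg(Seq(false), Seq(Ref(0, nullable = true))))
assert(!fixed.refs.head.nullable)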
@@ -35,7 +35,7 @@ trait AliasHelper {
protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
-val aliasMap = plan.aggregateExpressions.collect {
+val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
(a.toAttribute, a)
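aggregateExpressionsWithoutGroupingRefs, used above in place of aggregateExpressions, is a companion helper the PR adds on Aggregate. Judging by its name and its call sites, it returns the aggregate list with each GroupingExprRef resolved back to the grouping expression it points at, so code that wants to see ordinary expressions (alias collection here, subquery rewriting and physical planning further down) is insulated from the new reference nodes.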
@@ -80,6 +80,14 @@ object AggregateExpression {
filter,
NamedExpression.newExprId)
}

def containsAggregate(expr: Expression): Boolean = {
expr.find(isAggregate).isDefined
}

def isAggregate(expr: Expression): Boolean = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
}
}
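The split matters: isAggregate tests only the node it is given, while containsAggregate searches the whole subtree via Expression.find, so an expression like sum(x) + 1 contains an aggregate without being one itself. Callers below pick whichever of the two semantics they actually need.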

/**
@@ -277,3 +277,22 @@ object GroupingAnalytics {
}
}
}

/**
* A reference to a grouping expression in an [[Aggregate]] node.
*
* @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression
* refers to.
* @param dataType The [[DataType]] of the referenced grouping expression.
* @param nullable True if null is a valid value for the referenced grouping expression.
*/
case class GroupingExprRef(
ordinal: Int,
dataType: DataType,
nullable: Boolean)
extends LeafExpression with Unevaluable {

override def stringArgs: Iterator[Any] = {
Iterator(ordinal)
}
}
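For example, in SELECT struct(a, b) AS s, count(*) FROM t GROUP BY struct(a, b), the aggregate list can carry a GroupingExprRef pointing at ordinal 0 in place of a second copy of struct(a, b). Because the reference is by ordinal rather than by expression, optimizer rules that rewrite the select list can no longer make it diverge from the grouping list.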
Expand Up @@ -18,23 +18,14 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

/**
* Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
*/
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-// One place where this optimization is invalid is an aggregation where the select
-// list expression is a function of a grouping expression:
-//
-//   SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
-//
-// cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
-// optimization for Aggregates (although this misses some cases where the optimization
-// can be made).
-case a: Aggregate => a

peter-toth (Contributor, Author) commented on the removed lines: We can remove this limitation now.
case p => p.transformExpressionsUp {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
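The carve-out above could be dropped because of the GroupingExprRef change: in SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b), the select list now refers to the grouping expression through an ordinal reference, so simplifying extract operations can no longer rewrite the select list into something (such as a bare a) that is not a function of the grouping expressions.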
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

/**
* This rule ensures that [[Aggregate]] nodes contain all the [[GroupingExprRef]]
* references required by the optimization phase.
*/
object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan transform {
case a: Aggregate =>
Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child)
}
}
}
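Conceptually, Aggregate.withGroupingRefs (the helper this rule calls, added by this PR) substitutes an ordinal reference for every occurrence of a grouping expression inside the aggregate list. A minimal stand-alone model of that substitution, using a toy expression tree rather than Catalyst's (paste into a Scala REPL):

// Toy expression tree; GroupRef plays the role of GroupingExprRef.
sealed trait Expr
case class Col(name: String) extends Expr
case class Struct(fields: List[Expr]) extends Expr
case class GroupRef(ordinal: Int) extends Expr

// Replace any subexpression equal to a grouping expression with its ordinal.
def withRefs(grouping: List[Expr], aggExprs: List[Expr]): List[Expr] = {
  def rewrite(e: Expr): Expr = grouping.indexOf(e) match {
    case -1 => e match {
      case Struct(fs) => Struct(fs.map(rewrite))
      case other => other
    }
    case i => GroupRef(i)
  }
  aggExprs.map(rewrite)
}

// GROUP BY struct(a, b) with the same struct in the select list:
val g = List(Struct(List(Col("a"), Col("b"))))
assert(withRefs(g, g) == List(GroupRef(0)))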
@@ -118,7 +118,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
OptimizeUpdateFields,
SimplifyExtractValueOps,
OptimizeCsvJsonExprs,
-CombineConcats) ++
+CombineConcats,
+UpdateGroupingExprRefNullability) ++
extendedOperatorOptimizationRules

val operatorOptimizationBatch: Seq[Batch] = {
@@ -147,6 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView,
ReplaceExpressions,
RewriteNonCorrelatedExists,
EnforceGroupingReferencesInAggregates,
ComputeCurrentTime,
GetCurrentDatabaseAndCatalog(catalogManager)) ::
//////////////////////////////////////////////////////////////////////////////////////////
@@ -266,7 +268,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
-ReplaceUpdateFieldsExpression.ruleName :: Nil
+ReplaceUpdateFieldsExpression.ruleName ::
+EnforceGroupingReferencesInAggregates.ruleName ::
+UpdateGroupingExprRefNullability.ruleName :: Nil
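Judging by the "make new rules non excludable" commit, this list is the optimizer's non-excludable rule set. The choice looks deliberate: once EnforceGroupingReferencesInAggregates has injected GroupingExprRef nodes, later Aggregate rewrites rely on those ordinals staying present and consistent, so letting either rule be disabled through spark.sql.optimizer.excludedRules could yield invalid plans.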

/**
* Optimize all the subqueries inside expression.
@@ -506,7 +510,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)

-val newAggregate = upper.copy(
+val newAggregate = Aggregate.withGroupingRefs(
child = lower.child,
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
aggregateExpressions = upper.aggregateExpressions.map(
@@ -522,23 +526,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
}

private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
-val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
+val upperHasNoAggregateExpressions =
+  !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)

lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
lower
.aggregateExpressions
.filter(_.deterministic)
-.filter(!isAggregate(_))
+.filterNot(AggregateExpression.containsAggregate)
.map(_.toAttribute)
))

upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
}

-private def isAggregate(expr: Expression): Boolean = {
-  expr.find(e => e.isInstanceOf[AggregateExpression] ||
-    PythonUDF.isGroupedAggPandasUDF(e)).isDefined
-}
}

/**
@@ -1979,7 +1979,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
if (newGrouping.nonEmpty) {
-a.copy(groupingExpressions = newGrouping)
+val droppedGroupsBefore =
+  grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray
+
+val newAggregateExpressions =
+  a.aggregateExpressions.map(_.transform {
+    case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
+      g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
+  }.asInstanceOf[NamedExpression])
+
+a.copy(
+  groupingExpressions = newGrouping,
+  aggregateExpressions = newAggregateExpressions)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
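The scanLeft bookkeeping above deserves a closer look: droppedGroupsBefore(i) is the number of foldable grouping expressions at ordinals strictly less than i, which is exactly how far surviving ordinal i must shift down. A stand-alone illustration, assuming a hypothetical grouping list [lit(1), a, lit(2), b] (Scala REPL):

// Positions 0 and 2 of the grouping list are foldable literals.
val foldable = Seq(true, false, true, false)

// droppedGroupsBefore(i) = foldable entries before position i.
val droppedGroupsBefore =
  foldable.scanLeft(0)((n, f) => n + (if (f) 1 else 0)).toArray
// -> Array(0, 1, 1, 2, 2)

// Column a moves from ordinal 1 to 0; column b from ordinal 3 to 1.
assert(1 - droppedGroupsBefore(1) == 0)
assert(3 - droppedGroupsBefore(3) == 1)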
@@ -2000,7 +2011,25 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
if (newGrouping.size == grouping.size) {
a
} else {
-a.copy(groupingExpressions = newGrouping)
+var i = 0
+val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
+  n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
+    i += 1
+    0
+  } else {
+    1
+  })
+).toArray
+
+val newAggregateExpressions =
+  a.aggregateExpressions.map(_.transform {
+    case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
+      g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
+  }.asInstanceOf[NamedExpression])
+
+a.copy(
+  groupingExpressions = newGrouping,
+  aggregateExpressions = newAggregateExpressions)
}
}
}
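The deduplication variant needs the extra cursor i because membership is checked with reference equality (eq) against the deduplicated list: semantically equal duplicates are distinct objects, and only the dropped occurrences shift the ordinals of the grouping expressions that follow them. The remapping itself is the same droppedGroupsBefore technique illustrated above.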
@@ -632,9 +632,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelper {
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
-case a @ Aggregate(grouping, expressions, child) =>
+case a @ Aggregate(grouping, _, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
-val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
+  .map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
@@ -287,7 +287,7 @@ object PhysicalAggregation {
(Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
-case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of semantically distinct aggregate expressions and re-write expressions so
@@ -297,11 +297,9 @@
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
// addExpr() always returns false for non-deterministic expressions and does not add them.
-case agg: AggregateExpression
-  if !equivalentAggregateExpressions.addExpr(agg) => agg
-case udf: PythonUDF
-  if PythonUDF.isGroupedAggPandasUDF(udf) &&
-    !equivalentAggregateExpressions.addExpr(udf) => udf
+case a
+  if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
+    a
}
}

@@ -322,7 +320,7 @@
// which takes the grouping columns and final aggregate result buffer as input.
// Thus, we must re-write the result expressions so that their attributes match up with
// the attributes of the final result projection's input row:
-val rewrittenResultExpressions = resultExpressions.map { expr =>
+val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
expr.transformDown {
case ae: AggregateExpression =>
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,
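PhysicalAggregation, like the other consumers above, reads the aggregate list through aggregateExpressionsWithoutGroupingRefs, i.e. with ordinal references materialized back into real expressions before physical planning. A sketch of the inverse of the earlier withRefs model, on the same toy expression tree (types repeated so the snippet stands alone, Scala REPL):

sealed trait Expr
case class Col(name: String) extends Expr
case class Struct(fields: List[Expr]) extends Expr
case class GroupRef(ordinal: Int) extends Expr

// Inverse of withRefs: look each ordinal reference up in the grouping list.
def withoutRefs(grouping: List[Expr], aggExprs: List[Expr]): List[Expr] = {
  def resolve(e: Expr): Expr = e match {
    case GroupRef(i) => grouping(i)
    case Struct(fs) => Struct(fs.map(resolve))
    case other => other
  }
  aggExprs.map(resolve)
}

val g = List(Struct(List(Col("a"), Col("b"))))
assert(withoutRefs(g, List(GroupRef(0))) == g)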