
Commit 6e0ef86

bersprockets authored and cloud-fan committed
[SPARK-40382][SQL] Group distinct aggregate expressions by semantically equivalent children in RewriteDistinctAggregates
### What changes were proposed in this pull request?

In `RewriteDistinctAggregates`, when grouping aggregate expressions by function children, treat children that are semantically equivalent as the same.

### Why are the changes needed?

This PR reduces the number of projections in the Expand operator when there are multiple distinct aggregations with superficially different children. In some cases, it eliminates the need for an Expand operator entirely.

Example: in the following query, the Expand operator creates 3*n rows (where n is the number of incoming rows) because it has a projection for each of the function children `b + 1`, `1 + b` and `c`.

```sql
create or replace temp view v1 as
select * from values
  (1, 2, 3.0),
  (1, 3, 4.0),
  (2, 4, 2.5),
  (2, 3, 1.0)
  v1(a, b, c);

select
  a,
  count(distinct b + 1),
  avg(distinct 1 + b) filter (where c > 0),
  sum(c)
from v1
group by a;
```

The Expand operator has three projections (each producing a row for each incoming row):

```
[a#87, null, null, 0, null, UnscaledValue(c#89)],  <== projection #1 (for regular aggregation)
[a#87, (b#88 + 1), null, 1, null, null],           <== projection #2 (for distinct aggregation of b + 1)
[a#87, null, (1 + b#88), 2, (c#89 > 0.0), null]],  <== projection #3 (for distinct aggregation of 1 + b)
```

In reality, the Expand needs only one projection for `1 + b` and `b + 1`, because they are semantically equivalent. With the proposed change, the Expand operator's projections look like this:

```
[a#67, null, 0, null, UnscaledValue(c#69)],  <== projection #1 (for regular aggregations)
[a#67, (b#68 + 1), 1, (c#69 > 0.0), null]],  <== projection #2 (for distinct aggregation on b + 1 and 1 + b)
```

With one less projection, Expand produces 2*n rows instead of 3*n rows, but still produces the correct result.

In the case where all distinct aggregates have semantically equivalent children, the Expand operator is not needed at all.

Benchmark code is in the JIRA ticket (SPARK-40382). Before the PR:

```
distinct aggregates:            Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------
all semantically equivalent             14721           14859         195         5.7         175.5       1.0X
some semantically equivalent            14569           14572           5         5.8         173.7       1.0X
none semantically equivalent            14408           14488         113         5.8         171.8       1.0X
```

After the PR:

```
distinct aggregates:            Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------
all semantically equivalent              3658            3692          49        22.9          43.6       1.0X
some semantically equivalent             9124            9214         127         9.2         108.8       0.4X
none semantically equivalent            14601           14777         250         5.7         174.1       0.3X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit tests.

Closes #37825 from bersprockets/rewritedistinct_issue.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
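The rewrite hinges on Catalyst's existing notion of semantic equivalence: two expressions are treated as the same if their canonicalized forms match. A minimal standalone sketch of that behavior (illustrative only, not code from this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, ExpressionSet, Literal}
import org.apache.spark.sql.types.IntegerType

object SemanticEquivalenceSketch extends App {
  val b = AttributeReference("b", IntegerType)()

  val e1 = Add(b, Literal(1)) // b + 1
  val e2 = Add(Literal(1), b) // 1 + b: cosmetically different

  // Canonicalization normalizes commutative operations, so the two
  // expressions are semantically equal even though they differ
  // structurally (and therefore under plain case-class equality).
  assert(e1.semanticEquals(e2))
  assert(e1 != e2)

  // A plain Set keeps both entries; an ExpressionSet collapses them,
  // which is why grouping by ExpressionSet merges the two projections.
  assert(Set(e1, e2).size == 2)
  assert(ExpressionSet(Seq(e1, e2)).size == 1)
}
```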
1 parent 9bc8c06 · commit 6e0ef86

6 files changed, +87 -9 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 6 additions & 4 deletions
```diff
@@ -220,7 +220,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+      val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
       if (unfoldableChildren.nonEmpty) {
         // Only expand the unfoldable children
         unfoldableChildren
@@ -231,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         // count(distinct 1) will be explained to count(1) after the rewrite function.
         // Generally, the distinct aggregateFunction should not run
         // foldable TypeCheck for the first child.
-        e.aggregateFunction.children.take(1).toSet
+        ExpressionSet(e.aggregateFunction.children.take(1))
       }
     }
 
@@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
     // Setup unique distinct aggregate children.
     val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-    val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+    val distinctAggChildAttrMap = distinctAggChildren.map { e =>
+      e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
+    }
     val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
     // Setup all the filters in distinct aggregate.
     val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect {
@@ -292,7 +294,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         af
       } else {
         patchAggregateFunctionChildren(af) { x =>
-          distinctAggChildAttrLookup.get(x)
+          distinctAggChildAttrLookup.get(x.canonicalized)
         }
       }
       val newCondition = if (condition.isDefined) {
```
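A hedged illustration of what the `groupBy` change above buys (the expressions are made up for the example): with a plain `Set` key, children that differ only cosmetically land in different groups and each group costs one Expand projection; with an `ExpressionSet` key, whose equality is defined over canonicalized contents, they share a group.

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, ExpressionSet, Literal}
import org.apache.spark.sql.types.IntegerType

val b = AttributeReference("b", IntegerType)()

// Unfoldable children of count(distinct b + 1) and avg(distinct 1 + b).
val children1 = Seq(Add(b, Literal(1)))
val children2 = Seq(Add(Literal(1), b))

// Keyed by a plain Set: two groups, hence two Expand projections.
assert(Seq(children1, children2).groupBy(_.toSet).size == 2)

// Keyed by ExpressionSet: one group, hence one Expand projection.
assert(Seq(children1, children2).groupBy(c => ExpressionSet(c)).size == 1)
```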

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala

Lines changed: 33 additions & 0 deletions
```diff
@@ -75,4 +75,37 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       .analyze
     checkRewrite(RewriteDistinctAggregates(input))
   }
+
+  test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy($"a")(
+        countDistinct($"b" + $"c").as("agg1"),
+        countDistinct($"c" + $"b").as("agg2"),
+        max($"c").as("agg3"))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, LocalRelation(_, _, _)) =>
+      case _ => fail(s"Plan is not as expected:\n$rewrite")
+    }
+  }
+
+  test("SPARK-40382: reduce multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy($"a")(
+        countDistinct($"b" + $"c" + $"d").as("agg1"),
+        countDistinct($"d" + $"c" + $"b").as("agg2"),
+        countDistinct($"b" + $"c").as("agg3"),
+        countDistinct($"c" + $"b").as("agg4"),
+        max($"c").as("agg5"))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, Aggregate(_, _, e: Expand)) =>
+        assert(e.projections.size == 3)
+      case _ => fail(s"Plan is not rewritten:\n$rewrite")
+    }
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
-        if (functionsWithDistinct.map(
-            _.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) {
+        val distinctAggChildSets = functionsWithDistinct.map { ae =>
+          ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable))
+        }.distinct
+        if (distinctAggChildSets.length > 1) {
           // This is a sanity check. We should not reach here when we have multiple distinct
           // column sets. Our `RewriteDistinctAggregates` should take care this case.
           throw new IllegalStateException(
```
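The planner-side sanity check relies on the same property: `Seq#distinct` compares the `ExpressionSet`s by their canonicalized contents, so child sets that differ only cosmetically collapse to a single entry and the check no longer trips on plans the optimizer has already merged. A small sketch, assuming `ExpressionSet` equality behaves as the check above requires:

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, ExpressionSet, Literal}
import org.apache.spark.sql.types.IntegerType

val j = AttributeReference("j", IntegerType)()

// Child sets of, e.g., sum(DISTINCT j + 1) and max(DISTINCT 1 + j).
val s1 = ExpressionSet(Seq(Add(j, Literal(1))))
val s2 = ExpressionSet(Seq(Add(Literal(1), j)))

// distinct collapses the semantically equal sets, so length == 1
// and the IllegalStateException branch is not taken.
assert(Seq(s1, s2).distinct.length == 1)
```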

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 6 additions & 3 deletions
```diff
@@ -219,14 +219,17 @@ object AggUtils {
     }
 
     // 3. Create an Aggregate operator for partial aggregation (for distinct)
-    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
+    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
+      distinctAttributes)
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
       // to AttributeReferences.
       case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
-        aggregateFunction.transformDown(distinctColumnAttributeLookup)
-          .asInstanceOf[AggregateFunction]
+        aggregateFunction.transformDown {
+          case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
+            distinctColumnAttributeLookup(e.canonicalized)
+        }.asInstanceOf[AggregateFunction]
       case agg =>
         throw new IllegalArgumentException(
           "Non-distinct aggregate is found in functionsWithDistinct " +
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 34 additions & 0 deletions
```diff
@@ -1485,6 +1485,40 @@ class DataFrameAggregateSuite extends QueryTest
     val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
     checkAnswer(df, Row(2, 3, 1))
   }
+
+  test("SPARK-40382: Distinct aggregation expression grouping by semantic equivalence") {
+    Seq(
+      (1, 1, 3),
+      (1, 2, 3),
+      (1, 2, 3),
+      (2, 1, 1),
+      (2, 2, 5)
+    ).toDF("k", "c1", "c2").createOrReplaceTempView("df")
+
+    // all distinct aggregation children are semantically equivalent
+    val res1 = sql(
+      """select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res1, Row(1, 5, 2.5, 2) :: Row(2, 5, 2.5, 2) :: Nil)
+
+    // some distinct aggregation children are semantically equivalent
+    val res2 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 2 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res2, Row(1, 7, 3.5, 1) :: Row(2, 7, 3.5, 2) :: Nil)
+
+    // no distinct aggregation children are semantically equivalent
+    val res3 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 3 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil)
+  }
 }
 
 case class B(c: Option[Double])
```
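As a sanity check on the expected answers: for k = 1 the distinct values of c1 + 1 are {2, 3}, so sum = 5, avg = 2.5, and count = 2, matching Row(1, 5, 2.5, 2); the remaining rows follow the same arithmetic.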

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     // 2 distinct columns with different order
     val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
     assertNoExpand(query3.queryExecution.executedPlan)
+
+    // SPARK-40382: 1 distinct expression with cosmetic differences
+    val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
+    assertNoExpand(query4.queryExecution.executedPlan)
   }
 }
 
```
