
Commit 41a940f

c21 authored and cloud-fan committed
[SPARK-37557][SQL] Replace object hash with sort aggregate if child is already sorted
### What changes were proposed in this pull request?

This is a follow-up of #34702 (comment), where object hash aggregate can be replaced with sort aggregate as well. This PR handles object hash aggregate.

### Why are the changes needed?

Increase coverage of the rule by handling object hash aggregate as well.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Modified unit test in `ReplaceHashWithSortAggSuite.scala` to cover object hash aggregate (by using the aggregate expression `COLLECT_LIST`).

Closes #34824 from c21/agg-rule-followup.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 116255d commit 41a940f
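
For orientation, here is a minimal sketch (an assumed spark-shell session, not part of this patch) of the query shape the change targets: a COLLECT_LIST aggregation whose child is already sorted on the grouping key. The config key spark.sql.execution.replaceHashWithSortAgg and the assumption that the rule is disabled by default are inferred from the rule's name, not from this diff.

    // Assumed config key; enable the replacement rule for this session.
    spark.conf.set("spark.sql.execution.replaceHashWithSortAgg", "true")

    spark.range(100).selectExpr("id AS key").repartition(10).createOrReplaceTempView("t")

    // Mirrors the query in ReplaceHashWithSortAggSuite: the subquery ends with SORT BY key,
    // so the partial aggregate's child is already ordered on the grouping key.
    spark.sql(
      """
        |SELECT key, COLLECT_LIST(key)
        |FROM (SELECT key FROM t WHERE key > 10 SORT BY key)
        |GROUP BY key
      """.stripMargin).explain()
    // Expected with this change: the partial aggregate shows up as SortAggregate instead of
    // ObjectHashAggregate, while the final aggregate stays hash-based.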

4 files changed (+94, -69 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala
27 additions & 13 deletions

@@ -20,17 +20,18 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
- * Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the spark plan if:
+ * Replace hash-based aggregate with sort aggregate in the spark plan if:
  *
- * 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of partial
- *    aggregate satisfies the sort order of corresponding [[SortAggregateExec]].
+ * 1. The plan is a pair of partial and final [[HashAggregateExec]] or [[ObjectHashAggregateExec]],
+ *    and the child of partial aggregate satisfies the sort order of corresponding
+ *    [[SortAggregateExec]].
  * or
- * 2. The plan is a [[HashAggregateExec]], and the child satisfies the sort order of
- *    corresponding [[SortAggregateExec]].
+ * 2. The plan is a [[HashAggregateExec]] or [[ObjectHashAggregateExec]], and the child satisfies
+ *    the sort order of corresponding [[SortAggregateExec]].
  *
  * Examples:
  * 1. aggregate after join:
@@ -47,9 +48,9 @@ import org.apache.spark.sql.internal.SQLConf
  *              |            =>            |
  *          Sort(t1.i)                 Sort(t1.i)
  *
- * [[HashAggregateExec]] can be replaced when its child satisfies the sort order of
- * corresponding [[SortAggregateExec]]. [[SortAggregateExec]] is faster in the sense that
- * it does not have hashing overhead of [[HashAggregateExec]].
+ * Hash-based aggregate can be replaced when its child satisfies the sort order of
+ * corresponding sort aggregate. Sort aggregate is faster in the sense that
+ * it does not have hashing overhead of hash aggregate.
  */
 object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
@@ -61,14 +62,15 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
   }
 
   /**
-   * Replace [[HashAggregateExec]] with [[SortAggregateExec]].
+   * Replace [[HashAggregateExec]] and [[ObjectHashAggregateExec]] with [[SortAggregateExec]].
    */
   private def replaceHashAgg(plan: SparkPlan): SparkPlan = {
     plan.transformDown {
-      case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty =>
+      case hashAgg: BaseAggregateExec if isHashBasedAggWithKeys(hashAgg) =>
         val sortAgg = hashAgg.toSortAggregate
         hashAgg.child match {
-          case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) =>
+          case partialAgg: BaseAggregateExec
+              if isHashBasedAggWithKeys(partialAgg) && isPartialAgg(partialAgg, hashAgg) =>
             if (SortOrder.orderingSatisfies(
                 partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
               sortAgg.copy(
@@ -92,7 +94,7 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
   /**
    * Check if `partialAgg` to be partial aggregate of `finalAgg`.
   */
-  private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
+  private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = {
     if (partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
       finalAgg.aggregateExpressions.forall(_.mode == Final)) {
       (finalAgg.logicalLink, partialAgg.logicalLink) match {
@@ -103,4 +105,16 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
       false
     }
   }
+
+  /**
+   * Check if `agg` is [[HashAggregateExec]] or [[ObjectHashAggregateExec]],
+   * and has grouping keys.
+   */
+  private def isHashBasedAggWithKeys(agg: BaseAggregateExec): Boolean = {
+    val isHashBasedAgg = agg match {
+      case _: HashAggregateExec | _: ObjectHashAggregateExec => true
+      case _ => false
+    }
+    isHashBasedAgg && agg.groupingExpressions.nonEmpty
+  }
 }
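
The condition the rule checks is visible above: the aggregate child's output ordering must already satisfy the ordering the corresponding SortAggregateExec would require. A small exploration sketch (assumed spark-shell session; it uses only the APIs shown in this diff) that inspects that condition on a planned query without invoking the rule:

    import org.apache.spark.sql.catalyst.expressions.SortOrder
    import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec

    // Pre-AQE physical plan of a COLLECT_LIST aggregation over a sorted child.
    val plan = spark.sql(
      "SELECT key, COLLECT_LIST(key) FROM (SELECT id AS key FROM range(100) SORT BY key) GROUP BY key"
    ).queryExecution.sparkPlan

    // For each object hash aggregate, check whether its child ordering already satisfies the
    // sort order required by the equivalent sort aggregate (the same check the rule performs).
    plan.collect { case agg: ObjectHashAggregateExec => agg }.foreach { agg =>
      val sortAgg = agg.toSortAggregate
      val satisfied = SortOrder.orderingSatisfies(
        agg.child.outputOrdering, sortAgg.requiredChildOrdering.head)
      println(s"${agg.nodeName}: child ordering satisfies sort aggregate requirement = $satisfied")
    }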

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
10 additions & 0 deletions

@@ -30,6 +30,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
   def groupingExpressions: Seq[NamedExpression]
   def aggregateExpressions: Seq[AggregateExpression]
   def aggregateAttributes: Seq[Attribute]
+  def initialInputBufferOffset: Int
   def resultExpressions: Seq[NamedExpression]
 
   override def verboseStringWithOperatorId(): String = {
@@ -95,4 +96,13 @@
       case None => UnspecifiedDistribution :: Nil
     }
   }
+
+  /**
+   * The corresponding [[SortAggregateExec]] to get same result as this node.
+   */
+  def toSortAggregate: SortAggregateExec = {
+    SortAggregateExec(
+      requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
+      aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
+  }
 }
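
Because toSortAggregate now lives on the BaseAggregateExec trait rather than only on HashAggregateExec, both hash-based operators expose the same conversion. A hypothetical helper sketch (not part of this patch) showing how a caller can treat them uniformly, in the spirit of the rule above:

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}

    // Hypothetical helper: convert either hash-based aggregate operator to its sort-based
    // equivalent via the shared trait method, leaving all other plan nodes untouched.
    def asSortAggregate(plan: SparkPlan): SparkPlan = plan match {
      case agg: HashAggregateExec => agg.toSortAggregate
      case agg: ObjectHashAggregateExec => agg.toSortAggregate
      case other => other
    }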

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
0 additions & 9 deletions

@@ -1153,15 +1153,6 @@ case class HashAggregateExec(
     }
   }
 
-  /**
-   * The corresponding [[SortAggregateExec]] to get same result as this node.
-   */
-  def toSortAggregate: SortAggregateExec = {
-    SortAggregateExec(
-      requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
-      aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
-  }
-
   override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
     copy(child = newChild)
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
57 additions & 47 deletions

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -30,7 +30,9 @@ abstract class ReplaceHashWithSortAggSuiteBase
 
   private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
     val plan = df.queryExecution.executedPlan
-    assert(collectWithSubqueries(plan) { case s: HashAggregateExec => s }.length == hashAggCount)
+    assert(collectWithSubqueries(plan) {
+      case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+    }.length == hashAggCount)
     assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
   }
 
@@ -55,71 +57,79 @@ abstract class ReplaceHashWithSortAggSuiteBase
   test("replace partial hash aggregate with sort aggregate") {
     withTempView("t") {
       spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t")
-      val query =
-        """
-          |SELECT key, FIRST(key)
-          |FROM
-          |(
-          |  SELECT key
-          |  FROM t
-          |  WHERE key > 10
-          |  SORT BY key
-          |)
-          |GROUP BY key
-        """.stripMargin
-      checkAggs(query, 1, 1, 2, 0)
+      Seq("FIRST", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT key, $aggExpr(key)
+             |FROM
+             |(
+             |  SELECT key
+             |  FROM t
+             |  WHERE key > 10
+             |  SORT BY key
+             |)
+             |GROUP BY key
+           """.stripMargin
+        checkAggs(query, 1, 1, 2, 0)
+      }
     }
   }
 
   test("replace partial and final hash aggregate together with sort aggregate") {
     withTempView("t1", "t2") {
       spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
       spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
-      val query =
-        """
-          |SELECT key, COUNT(key)
-          |FROM
-          |(
-          |  SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
-          |  FROM t1
-          |  JOIN t2
-          |  ON t1.key = t2.key
-          |)
-          |GROUP BY key
-        """.stripMargin
-      checkAggs(query, 0, 1, 2, 0)
+      Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT key, $aggExpr(key)
+             |FROM
+             |(
+             |  SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
+             |  FROM t1
+             |  JOIN t2
+             |  ON t1.key = t2.key
+             |)
+             |GROUP BY key
+           """.stripMargin
+        checkAggs(query, 0, 1, 2, 0)
+      }
     }
   }
 
   test("do not replace hash aggregate if child does not have sort order") {
     withTempView("t1", "t2") {
       spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
       spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
-      val query =
-        """
-          |SELECT key, COUNT(key)
-          |FROM
-          |(
-          |  SELECT /*+ BROADCAST(t1) */ t1.key AS key
-          |  FROM t1
-          |  JOIN t2
-          |  ON t1.key = t2.key
-          |)
-          |GROUP BY key
-        """.stripMargin
-      checkAggs(query, 2, 0, 2, 0)
+      Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT key, $aggExpr(key)
+             |FROM
+             |(
+             |  SELECT /*+ BROADCAST(t1) */ t1.key AS key
+             |  FROM t1
+             |  JOIN t2
+             |  ON t1.key = t2.key
+             |)
+             |GROUP BY key
+           """.stripMargin
+        checkAggs(query, 2, 0, 2, 0)
+      }
     }
   }
 
   test("do not replace hash aggregate if there is no group-by column") {
     withTempView("t1") {
       spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
-      val query =
-        """
-          |SELECT COUNT(key)
-          |FROM t1
-        """.stripMargin
-      checkAggs(query, 2, 0, 2, 0)
+      Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT $aggExpr(key)
+             |FROM t1
+           """.stripMargin
+        checkAggs(query, 2, 0, 2, 0)
+      }
     }
   }
 }
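
A note on why COLLECT_LIST is the aggregate chosen to extend coverage: it is a TypedImperativeAggregate, so with default settings Spark plans it with ObjectHashAggregateExec, while COUNT and FIRST plan as HashAggregateExec. A quick sanity-check sketch (assumed spark-shell session; the role of spark.sql.execution.useObjectHashAggregateExec is stated from memory, not from this diff):

    import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
    import org.apache.spark.sql.functions.collect_list

    // With spark.sql.execution.useObjectHashAggregateExec at its default (true), a COLLECT_LIST
    // aggregation is expected to plan as ObjectHashAggregateExec, which is what lets this suite
    // exercise the object hash path of ReplaceHashWithSortAgg.
    val aggPlan = spark.range(100).selectExpr("id AS key")
      .groupBy("key")
      .agg(collect_list("key"))
      .queryExecution
      .sparkPlan  // pre-AQE physical plan, so the aggregate nodes are directly visible

    assert(aggPlan.collect { case agg: ObjectHashAggregateExec => agg }.nonEmpty)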
