[SPARK-37455][SQL] Replace hash with sort aggregate if child is already sorted #34702
Changes from all commits
@@ -0,0 +1,106 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.internal.SQLConf
/**
 * Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the spark plan if:
 *
 * 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of partial
 *    aggregate satisfies the sort order of corresponding [[SortAggregateExec]].
 * or
 * 2. The plan is a [[HashAggregateExec]], and the child satisfies the sort order of
 *    corresponding [[SortAggregateExec]].
 *
 * Examples:
 * 1. aggregate after join:
 *
 *   HashAggregate(t1.i, SUM, final)
 *    |                                    SortAggregate(t1.i, SUM, complete)
 *   HashAggregate(t1.i, SUM, partial) =>   |
 *    |                                    SortMergeJoin(t1.i = t2.j)
 *   SortMergeJoin(t1.i = t2.j)
 *
 * 2. aggregate after sort:
 *
 *   HashAggregate(t1.i, SUM, partial)     SortAggregate(t1.i, SUM, partial)
 *    |                                =>   |
 *   Sort(t1.i)                            Sort(t1.i)
 *
 * [[HashAggregateExec]] can be replaced when its child satisfies the sort order of
 * corresponding [[SortAggregateExec]]. [[SortAggregateExec]] is faster in the sense that
 * it does not have the hashing overhead of [[HashAggregateExec]].
 */

> Review (on the aggregate-after-join example): This seems like an orthogonal optimization: we can merge adjacent partial and final aggregates (no shuffle between them) into one complete aggregate.
>
> Reply: Yeah, I think we can add a rule later to optimize it. I vaguely remember someone proposed this in OSS before, but it seems the impact is not high.
object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED)) {
      plan
    } else {
      replaceHashAgg(plan)
    }
  }

  /**
   * Replace [[HashAggregateExec]] with [[SortAggregateExec]].
   */
  private def replaceHashAgg(plan: SparkPlan): SparkPlan = {
    plan.transformDown {
      case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty =>

> Review: BTW, shall we handle …
>
> Reply: @cloud-fan - yeah, I agree. I don't see a problem why we cannot do it. Created https://issues.apache.org/jira/browse/SPARK-37557 for follow-up. Will do it shortly, thanks.
        val sortAgg = hashAgg.toSortAggregate
        hashAgg.child match {
          case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) =>
            if (SortOrder.orderingSatisfies(
                partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
              sortAgg.copy(
                aggregateExpressions = sortAgg.aggregateExpressions.map(_.copy(mode = Complete)),
                child = partialAgg.child)
            } else {
              hashAgg
            }

> Review: Is it always right? I think we also need to check the output partitioning to see if we can eliminate the partial agg. An example is …
>
> Reply: @cloud-fan - I don't think we need to check output partitioning, as we are matching a pair of final and partial hash agg, without shuffle in between. So …
>
> Review: Ah ok, if there is a shuffle in the middle, we can't optimize? This looks quite limited, as having a shuffle in the middle is very common.
>
> Reply: We can, and the rule here also does pattern matching for a single …. Spark native shuffle does not guarantee any sort orders; for Cosco (a remote shuffle service we are running in-house), we support sorted shuffle, so the final aggregate can also be possible to replace.
          case other =>
            if (SortOrder.orderingSatisfies(
                other.outputOrdering, sortAgg.requiredChildOrdering.head)) {
              sortAgg
            } else {
              hashAgg
            }
        }
      case other => other
    }
  }
  /**
   * Check if `partialAgg` is the partial aggregate of `finalAgg`.
   */
  private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
    if (partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
        finalAgg.aggregateExpressions.forall(_.mode == Final)) {
      (finalAgg.logicalLink, partialAgg.logicalLink) match {
        case (Some(agg1), Some(agg2)) => agg1.sameResult(agg2)
        case _ => false
      }
    } else {
      false
    }
  }
}

> Review: This looks like reverse engineering the …
>
> Reply: @tanelk - yeah, I agree this is mostly reverse engineering and we can do a better job here. I tried to link partial and final agg in …. I found a more elegant way to do it, by checking that the linked logical plans of both aggs are the same. Updated.
>
> Reply: cc @cloud-fan for review, thanks.
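The rule hinges on `SortOrder.orderingSatisfies`: the replacement is legal only when the child's output ordering already covers the sort aggregate's required child ordering. Below is a stand-alone sketch of that prefix-style check, using a hypothetical `Order` case class rather than Spark's real `SortOrder` (which compares expressions semantically and also tracks null ordering), just to illustrate the shape of the test:

```scala
// Hypothetical, simplified model of an ordering: column name plus direction.
// Spark's real SortOrder carries expressions and null ordering; this ignores both.
case class Order(col: String, ascending: Boolean = true)

// A child's output ordering satisfies a required ordering when the required
// ordering is a per-position prefix of the output ordering.
def orderingSatisfies(output: Seq[Order], required: Seq[Order]): Boolean =
  required.length <= output.length &&
    required.zip(output).forall { case (r, o) => r == o }

// A child sorted on (i, j) satisfies an aggregate that groups by i alone.
assert(orderingSatisfies(Seq(Order("i"), Order("j")), Seq(Order("i"))))
// An unsorted child (empty output ordering) does not.
assert(!orderingSatisfies(Seq.empty, Seq(Order("i"))))
```

This is why the rule can drop the sort entirely: it never inserts a sort of its own, it only exploits an ordering that is already there.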
@@ -0,0 +1,131 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

abstract class ReplaceHashWithSortAggSuiteBase
  extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper {

  private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
    val plan = df.queryExecution.executedPlan
    assert(collectWithSubqueries(plan) { case s: HashAggregateExec => s }.length == hashAggCount)
    assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
  }

  private def checkAggs(
      query: String,
      enabledHashAggCount: Int,
      enabledSortAggCount: Int,
      disabledHashAggCount: Int,
      disabledSortAggCount: Int): Unit = {
    withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "true") {
      val df = sql(query)
      checkNumAggs(df, enabledHashAggCount, enabledSortAggCount)
      val result = df.collect()
      withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "false") {
        val df = sql(query)
        checkNumAggs(df, disabledHashAggCount, disabledSortAggCount)
        checkAnswer(df, result)
      }
    }
  }

  test("replace partial hash aggregate with sort aggregate") {
    withTempView("t") {
      spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t")
      val query =
        """
          |SELECT key, FIRST(key)
          |FROM
          |(
          |   SELECT key
          |   FROM t
          |   WHERE key > 10
          |   SORT BY key
          |)
          |GROUP BY key
        """.stripMargin
      checkAggs(query, 1, 1, 2, 0)
    }
  }

  test("replace partial and final hash aggregate together with sort aggregate") {
    withTempView("t1", "t2") {
      spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
      spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
      val query =
        """
          |SELECT key, COUNT(key)
          |FROM
          |(
          |   SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
          |   FROM t1
          |   JOIN t2
          |   ON t1.key = t2.key
          |)
          |GROUP BY key
        """.stripMargin
      checkAggs(query, 0, 1, 2, 0)
    }
  }

  test("do not replace hash aggregate if child does not have sort order") {
    withTempView("t1", "t2") {
      spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
      spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
      val query =
        """
          |SELECT key, COUNT(key)
          |FROM
          |(
          |   SELECT /*+ BROADCAST(t1) */ t1.key AS key
          |   FROM t1
          |   JOIN t2
          |   ON t1.key = t2.key
          |)
          |GROUP BY key
        """.stripMargin
      checkAggs(query, 2, 0, 2, 0)
    }
  }

  test("do not replace hash aggregate if there is no group-by column") {
    withTempView("t1") {
      spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
      val query =
        """
          |SELECT COUNT(key)
          |FROM t1
        """.stripMargin
      checkAggs(query, 2, 0, 2, 0)
    }
  }
}

class ReplaceHashWithSortAggSuite extends ReplaceHashWithSortAggSuiteBase
  with DisableAdaptiveExecutionSuite

class ReplaceHashWithSortAggSuiteAE extends ReplaceHashWithSortAggSuiteBase
  with EnableAdaptiveExecutionSuite
> Review: Is it because the planner is top-down, so we don't know the child ordering during planning? Then we have to add a new rule to change the agg algorithm in a post-hoc way.
>
> Reply: Yes, it is. If we change our planning to bottom-up and propagate each node's output ordering info during planning, then we can run this rule during planning. For now, we have to add it after EnsureRequirements.
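The post-hoc placement discussed above can be sketched on a toy plan tree. All case classes below are hypothetical stand-ins for Spark's physical operators (not the real `SparkPlan` nodes); the sketch only shows the shape of the rewrite: after sorts are already in place, a hash aggregate whose child sort covers its grouping keys as a prefix is switched to a sort aggregate, and nothing else changes:

```scala
// Hypothetical toy operators; not Spark's actual SparkPlan classes.
sealed trait Plan
case class Scan() extends Plan
case class Sort(keys: Seq[String], child: Plan) extends Plan
case class HashAgg(keys: Seq[String], child: Plan) extends Plan
case class SortAgg(keys: Seq[String], child: Plan) extends Plan

// Top-down rewrite: a HashAgg directly on top of a Sort whose keys start
// with the grouping keys can use the sort-based algorithm for free.
def rewrite(plan: Plan): Plan = plan match {
  case HashAgg(keys, sort @ Sort(sortKeys, _)) if sortKeys.startsWith(keys) =>
    SortAgg(keys, rewrite(sort))
  case HashAgg(keys, child) => HashAgg(keys, rewrite(child))
  case Sort(keys, child)    => Sort(keys, rewrite(child))
  case SortAgg(keys, child) => SortAgg(keys, rewrite(child))
  case leaf                 => leaf
}

// Sorted child: the hash aggregate becomes a sort aggregate.
assert(rewrite(HashAgg(Seq("i"), Sort(Seq("i", "j"), Scan()))) ==
  SortAgg(Seq("i"), Sort(Seq("i", "j"), Scan())))
// Unsorted child: the plan is left unchanged.
assert(rewrite(HashAgg(Seq("i"), Scan())) == HashAgg(Seq("i"), Scan()))
```

Running such a rewrite only after the sort requirements are finalized mirrors why the real rule sits after EnsureRequirements: before that point, the child's eventual output ordering is not yet known.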