[SPARK-37455][SQL] Replace hash with sort aggregate if child is already sorted #34702

Closed · wants to merge 5 commits
@@ -1504,6 +1504,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val REPLACE_HASH_WITH_SORT_AGG_ENABLED = buildConf("spark.sql.execution.replaceHashWithSortAgg")
.internal()
.doc("Whether to replace hash aggregate node with sort aggregate based on children's ordering")
.version("3.3.0")
.booleanConf
.createWithDefault(false)

val STATE_STORE_PROVIDER_CLASS =
buildConf("spark.sql.streaming.stateStore.providerClass")
.internal()
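Since the new flag is internal and defaults to false, it has to be enabled explicitly to exercise the rule. A minimal sketch, assuming an active SparkSession named `spark` on a build that includes this patch:

```scala
// Turn on the post-planning hash-to-sort-aggregate replacement for this session.
spark.conf.set("spark.sql.execution.replaceHashWithSortAgg", "true")
```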
@@ -423,6 +423,9 @@ object QueryExecution {
PlanSubqueries(sparkSession),
RemoveRedundantProjects,
EnsureRequirements(),
// `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
// sort order of each node is checked to be valid.
ReplaceHashWithSortAgg,

Contributor:

Is it because the planner is top-down, so we don't know the child's ordering during planning? Then we have to add a new rule to change the agg algorithm in a post-hoc way.

Contributor Author:

Yes, it is. If we changed our planning to bottom-up and propagated each node's output ordering during planning, then we could run this rule during planning. For now, we have to add it after EnsureRequirements.
// `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
// number of partitions when instantiating PartitioningCollection.
RemoveRedundantSorts,
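For context, these preparation rules are applied to the physical plan one after another, so placement in this list is what lets `ReplaceHashWithSortAgg` observe the sort orders that `EnsureRequirements` has already made explicit. A simplified sketch of the application loop, assuming the usual fold-over-rules shape of `QueryExecution.prepareForExecution` (not the exact code):

```scala
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Each rule rewrites the whole physical plan in turn, so later rules see the
// output of earlier ones; ReplaceHashWithSortAgg can therefore trust the sort
// orders that EnsureRequirements has already established.
def prepareForExecution(rules: Seq[Rule[SparkPlan]], plan: SparkPlan): SparkPlan =
  rules.foldLeft(plan) { case (p, rule) => rule.apply(p) }
```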
@@ -0,0 +1,106 @@ ReplaceHashWithSortAgg.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf

/**
* Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the Spark plan if:
*
* 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of the partial
*    aggregate satisfies the sort order of the corresponding [[SortAggregateExec]];
* or
* 2. The plan is a [[HashAggregateExec]], and its child satisfies the sort order of the
*    corresponding [[SortAggregateExec]].
*
* Examples:
* 1. aggregate after join:
*
* HashAggregate(t1.i, SUM, final)
*            |                            SortAggregate(t1.i, SUM, complete)
* HashAggregate(t1.i, SUM, partial)  =>              |
*            |                               SortMergeJoin(t1.i = t2.j)
*   SortMergeJoin(t1.i = t2.j)
*
* 2. aggregate after sort:
*
* HashAggregate(t1.i, SUM, partial)       SortAggregate(t1.i, SUM, partial)
*            |                       =>               |
*        Sort(t1.i)                               Sort(t1.i)
*
* [[HashAggregateExec]] can be replaced when its child satisfies the sort order of the
* corresponding [[SortAggregateExec]]. [[SortAggregateExec]] is faster in the sense that
* it does not have the hashing overhead of [[HashAggregateExec]].
*/

Contributor (on the "aggregate after join" example):

This seems like an orthogonal optimization: we can merge adjacent partial and final aggregates (with no shuffle between them) into one complete aggregate.

Contributor Author:

Yeah, I think we can add a rule later to optimize this. I vaguely remember someone proposed it in OSS before, but the impact did not seem high.
object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED)) {
plan
} else {
replaceHashAgg(plan)
}
}

/**
* Replace [[HashAggregateExec]] with [[SortAggregateExec]].
*/
private def replaceHashAgg(plan: SparkPlan): SparkPlan = {
plan.transformDown {
case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty =>
Contributor:

BTW, shall we handle ObjectHashAggregateExec as well?

Contributor Author:

@cloud-fan - yeah, I agree; I don't see a reason why we couldn't do it. Created https://issues.apache.org/jira/browse/SPARK-37557 for the follow-up. Will do it shortly, thanks.

val sortAgg = hashAgg.toSortAggregate
hashAgg.child match {
case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) =>
if (SortOrder.orderingSatisfies(
partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
sortAgg.copy(
  aggregateExpressions = sortAgg.aggregateExpressions.map(_.copy(mode = Complete)),
  child = partialAgg.child)

Contributor:

Is this always right? I think we also need to check the output partitioning to see whether we can eliminate the partial agg. An example is df.sortWithinPartitions: it does not cluster the data, it just sorts within each partition.

Contributor Author:

@cloud-fan - I don't think we need to check output partitioning, as we are matching a pair of final and partial hash aggregates with no shuffle in between:

  HashAggregate(final)
          |                          SortAggregate(complete)
  HashAggregate(partial)     =>               |
          |                                 child
        child

So the child must already have the proper output partitioning for SortAggregate; otherwise it could not satisfy the original HashAggregate(final)'s required distribution.

Contributor:

Ah OK, so if there is a shuffle in the middle, we can't optimize? This looks quite limited, as having a shuffle in the middle is very common.

Contributor Author:

> if there is a shuffle in the middle, we can't optimize?

We can, and the rule also pattern-matches a single HashAggregate (the `case other` branch below). I added a unit test in ReplaceHashWithSortAggSuite.scala, "replace partial hash aggregate with sort aggregate", to demonstrate replacing a partial aggregate. But I think it would be rare to be able to replace a final aggregate (though this rule covers that too), as a final aggregate almost always comes immediately after a shuffle, so there is no sort ordering before it.

Spark's native shuffle does not guarantee any sort order; with Cosco (a remote shuffle service we run in-house) we support sorted shuffle, so replacing the final aggregate is possible there as well.
} else {
hashAgg
}
case other =>
if (SortOrder.orderingSatisfies(
other.outputOrdering, sortAgg.requiredChildOrdering.head)) {
sortAgg
} else {
hashAgg
}
}
case other => other
}
}
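As an aside on the `SortOrder.orderingSatisfies` checks above: the comparison has prefix semantics, so a child sorted on more columns still satisfies a shorter required ordering, but not vice versa. A hedged REPL-style sketch using the catalyst expression DSL (assuming the DSL is on the classpath):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}

val a = 'a.int
val b = 'b.int
val childOrdering = Seq(SortOrder(a, Ascending), SortOrder(b, Ascending))
val required = Seq(SortOrder(a, Ascending))

SortOrder.orderingSatisfies(childOrdering, required) // true: [a, b] covers [a]
SortOrder.orderingSatisfies(required, childOrdering) // false: [a] does not cover [a, b]
```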

/**
* Check if `partialAgg` is the partial aggregate of `finalAgg`.
*/
private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
Contributor:

This looks like reverse engineering AggUtils. Could we just link the partial and final agg when they are constructed?

Contributor Author:

@tanelk - yeah, I agree this is mostly reverse engineering, and we can do a better job here. I tried linking the partial and final agg in AggUtils and checking whether the linked physical plans are the same. That does not quite work because we do top-down planning, so the linked partial agg is not the same as the planned partial agg (the linked partial agg still contains a PlanLater operator).

I found a more elegant way to do it: checking that the linked logical plans of both aggs are the same. Updated.

Contributor Author:

cc @cloud-fan for review, thanks.

if (partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
finalAgg.aggregateExpressions.forall(_.mode == Final)) {
(finalAgg.logicalLink, partialAgg.logicalLink) match {
case (Some(agg1), Some(agg2)) => agg1.sameResult(agg2)
case _ => false
}
} else {
false
}
}
}
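To make the `logicalLink` check concrete, here is a hedged REPL-style sketch. It assumes an active SparkSession `spark` with adaptive execution disabled, so the executed plan exposes the partial/final pair directly; the exact plan shape is an assumption, not a guarantee:

```scala
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

val plan = spark.range(10).groupBy("id").count().queryExecution.executedPlan

// Pre-order traversal: the final aggregate sits above the partial one.
val aggs = plan.collect { case h: HashAggregateExec => h }

// Both nodes were planned from the same logical Aggregate, so their logical
// links should compare equal under sameResult.
for {
  l1 <- aggs.head.logicalLink
  l2 <- aggs.last.logicalLink
} println(l1.sameResult(l2)) // expected: true
```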
@@ -116,6 +116,7 @@ case class AdaptiveSparkPlanExec(
Seq(
RemoveRedundantProjects,
ensureRequirements,
ReplaceHashWithSortAgg,
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan,
OptimizeSkewedJoin(ensureRequirements)
@@ -1153,6 +1153,15 @@ case class HashAggregateExec(
}
}

/**
* The corresponding [[SortAggregateExec]] to get the same result as this node.
*/
def toSortAggregate: SortAggregateExec = {
SortAggregateExec(
requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
}

override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
copy(child = newChild)
}
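A hedged sanity-check sketch for the conversion above, again assuming an active SparkSession `spark` with adaptive execution disabled (the plan shape is an assumption):

```scala
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}

val plan = spark.range(10).selectExpr("id % 3 AS k").groupBy("k").count()
  .queryExecution.executedPlan

val hashAgg = plan.collect { case h: HashAggregateExec => h }.head
val sortAgg: SortAggregateExec = hashAgg.toSortAggregate

// The conversion copies every constructor argument, so the two operators are
// attribute-for-attribute compatible and produce the same result rows.
assert(sortAgg.output == hashAgg.output)
```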
@@ -0,0 +1,131 @@ ReplaceHashWithSortAggSuite.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

abstract class ReplaceHashWithSortAggSuiteBase
extends QueryTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {

private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(collectWithSubqueries(plan) { case s: HashAggregateExec => s }.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}

private def checkAggs(
query: String,
enabledHashAggCount: Int,
enabledSortAggCount: Int,
disabledHashAggCount: Int,
disabledSortAggCount: Int): Unit = {
withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "true") {
val df = sql(query)
checkNumAggs(df, enabledHashAggCount, enabledSortAggCount)
val result = df.collect()
withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "false") {
val df = sql(query)
checkNumAggs(df, disabledHashAggCount, disabledSortAggCount)
checkAnswer(df, result)
}
}
}

test("replace partial hash aggregate with sort aggregate") {
withTempView("t") {
spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t")
val query =
"""
|SELECT key, FIRST(key)
|FROM
|(
| SELECT key
| FROM t
| WHERE key > 10
| SORT BY key
|)
|GROUP BY key
""".stripMargin
checkAggs(query, 1, 1, 2, 0)
}
}
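The same scenario expressed through the DataFrame API, as a hedged illustration (the exact plan shape can vary with configuration):

```scala
import org.apache.spark.sql.functions.first

// Enable the flag for this session (internal conf, default false).
spark.conf.set("spark.sql.execution.replaceHashWithSortAgg", "true")

val df = spark.range(100).selectExpr("id AS key")
  .repartition(10)
  .filter("key > 10")
  .sortWithinPartitions("key")
  .groupBy("key")
  .agg(first("key"))

// Expect the partial aggregate above the local sort to be planned as a
// SortAggregate; the final aggregate after the shuffle stays hash-based
// because shuffle output carries no sort order.
df.explain()
```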

test("replace partial and final hash aggregate together with sort aggregate") {
withTempView("t1", "t2") {
spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
val query =
"""
|SELECT key, COUNT(key)
|FROM
|(
| SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
| FROM t1
| JOIN t2
| ON t1.key = t2.key
|)
|GROUP BY key
""".stripMargin
checkAggs(query, 0, 1, 2, 0)
}
}

test("do not replace hash aggregate if child does not have sort order") {
withTempView("t1", "t2") {
spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
val query =
"""
|SELECT key, COUNT(key)
|FROM
|(
| SELECT /*+ BROADCAST(t1) */ t1.key AS key
| FROM t1
| JOIN t2
| ON t1.key = t2.key
|)
|GROUP BY key
""".stripMargin
checkAggs(query, 2, 0, 2, 0)
}
}

test("do not replace hash aggregate if there is no group-by column") {
withTempView("t1") {
spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
val query =
"""
|SELECT COUNT(key)
|FROM t1
""".stripMargin
checkAggs(query, 2, 0, 2, 0)
}
}
}

class ReplaceHashWithSortAggSuite extends ReplaceHashWithSortAggSuiteBase
with DisableAdaptiveExecutionSuite

class ReplaceHashWithSortAggSuiteAE extends ReplaceHashWithSortAggSuiteBase
with EnableAdaptiveExecutionSuite