
Commit fe07521

c21 authored and cloud-fan committed
[SPARK-32330][SQL] Preserve shuffled hash join build side partitioning
### What changes were proposed in this pull request?

Currently `ShuffledHashJoin.outputPartitioning` inherits from `HashJoin.outputPartitioning`, which only preserves the stream side's partitioning (`HashJoin.scala`):

```
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
```

This loses the build side's partitioning information and causes an extra shuffle if another join or group-by follows this join. Example:

```
withSQLConf(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
    SQLConf.SHUFFLE_PARTITIONS.key -> "2",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
  val df1 = spark.range(10).select($"id".as("k1"))
  val df2 = spark.range(30).select($"id".as("k2"))
  Seq("inner", "cross").foreach(joinType => {
    val plan = df1.join(df2, $"k1" === $"k2", joinType).groupBy($"k1").count()
      .queryExecution.executedPlan
    assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
    // No extra shuffle before aggregate
    assert(plan.collect { case _: ShuffleExchangeExec => true }.size === 2)
  })
}
```

Current physical plan (with an extra shuffle on `k1` before the aggregate):

```
*(4) HashAggregate(keys=[k1#220L], functions=[count(1)], output=[k1#220L, count#235L])
+- Exchange hashpartitioning(k1#220L, 2), true, [id=#117]
   +- *(3) HashAggregate(keys=[k1#220L], functions=[partial_count(1)], output=[k1#220L, count#239L])
      +- *(3) Project [k1#220L]
         +- ShuffledHashJoin [k1#220L], [k2#224L], Inner, BuildLeft
            :- Exchange hashpartitioning(k1#220L, 2), true, [id=#109]
            :  +- *(1) Project [id#218L AS k1#220L]
            :     +- *(1) Range (0, 10, step=1, splits=2)
            +- Exchange hashpartitioning(k2#224L, 2), true, [id=#111]
               +- *(2) Project [id#222L AS k2#224L]
                  +- *(2) Range (0, 30, step=1, splits=2)
```

Ideal physical plan (no shuffle on `k1` before the aggregate):

```
*(3) HashAggregate(keys=[k1#220L], functions=[count(1)], output=[k1#220L, count#235L])
+- *(3) HashAggregate(keys=[k1#220L], functions=[partial_count(1)], output=[k1#220L, count#239L])
   +- *(3) Project [k1#220L]
      +- ShuffledHashJoin [k1#220L], [k2#224L], Inner, BuildLeft
         :- Exchange hashpartitioning(k1#220L, 2), true, [id=#107]
         :  +- *(1) Project [id#218L AS k1#220L]
         :     +- *(1) Range (0, 10, step=1, splits=2)
         +- Exchange hashpartitioning(k2#224L, 2), true, [id=#109]
            +- *(2) Project [id#222L AS k2#224L]
               +- *(2) Range (0, 30, step=1, splits=2)
```

This is fixed by overriding the `outputPartitioning` method in `ShuffledHashJoinExec`, similar to `SortMergeJoinExec`. In addition, this PR also fixes one typo in `HashJoin`, as that code path is shared between broadcast hash join and shuffled hash join.

### Why are the changes needed?

To avoid an unnecessary shuffle in queries with multiple joins or group-bys, saving CPU and IO.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test in `JoinSuite`.

Closes #29130 from c21/shj.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
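For intuition on why exposing both children's partitioning removes the exchange, here is a minimal sketch against catalyst's physical-plan API, assuming the spark-catalyst classes on the classpath; the attributes `k1`/`k2` and the 2-partition setup are illustrative, mirroring the example above:

```
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.types.LongType

object PreservePartitioningSketch extends App {
  // Illustrative join-key attributes, standing in for k1 and k2 above.
  val k1 = AttributeReference("k1", LongType)()
  val k2 = AttributeReference("k2", LongType)()

  // What each child reports after its shuffle: hash-partitioned on its key.
  val leftPart = HashPartitioning(Seq(k1), 2)
  val rightPart = HashPartitioning(Seq(k2), 2)

  // The aggregate on k1 requires its input to be clustered by k1.
  val aggNeeds = ClusteredDistribution(Seq(k1))

  // Before the fix, an inner ShuffledHashJoin with BuildLeft exposed only
  // the stream (right) side, which does not satisfy the requirement:
  println(rightPart.satisfies(aggNeeds)) // false -> EnsureRequirements adds an Exchange

  // After the fix, both sides are exposed, and the left one satisfies it:
  val both = PartitioningCollection(Seq(leftPart, rightPart))
  println(both.satisfies(aggNeeds)) // true -> no extra Exchange
}
```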
1 parent e0ecb66 commit fe07521

File tree

5 files changed (+66, -19 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 1 addition & 1 deletion
```
@@ -215,7 +215,7 @@ trait HashJoin extends BaseJoinExec {
         existenceJoin(streamedIter, hashed)
       case x =>
         throw new IllegalArgumentException(
-          s"BroadcastHashJoin should not take $x as the JoinType")
+          s"HashJoin should not take $x as the JoinType")
     }

     val resultProj = createResultProjection
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 2 additions & 3 deletions
```
@@ -40,15 +40,14 @@ case class ShuffledHashJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends HashJoin {
+  extends HashJoin with ShuffledJoin {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

-  override def requiredChildDistribution: Seq[Distribution] =
-    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning

   private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")
```
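The explicit `super[ShuffledJoin]` qualifier matters here: both `HashJoin` and the new `ShuffledJoin` trait define `outputPartitioning`, so the override must name which parent's implementation to reuse. A toy sketch of this plain-Scala mechanism (all names here are illustrative, not Spark's):

```
// Toy analogue of the mixin situation in ShuffledHashJoinExec.
trait StreamSideOnly {
  def outputPartitioning: String = "stream side partitioning"
}

trait BothSides {
  def outputPartitioning: String = "both sides' partitioning"
}

// Inheriting the same concrete method from two traits forces an override;
// the qualified super call pins the implementation we want to inherit.
class Joined extends StreamSideOnly with BothSides {
  override def outputPartitioning: String = super[BothSides].outputPartitioning
}

object QualifiedSuperDemo extends App {
  println(new Joined().outputPartitioning) // prints "both sides' partitioning"
}
```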
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala

Lines changed: 43 additions & 0 deletions

```
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, Partitioning, PartitioningCollection, UnknownPartitioning}
+
+/**
+ * Holds common logic for join operators by shuffling two child relations
+ * using the join keys.
+ */
+trait ShuffledJoin extends BaseJoinExec {
+  override def requiredChildDistribution: Seq[Distribution] = {
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  }
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(
+        s"ShuffledJoin should not take $x as the JoinType")
+  }
+}
```
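A note on the case analysis: it matches the comment previously in `SortMergeJoinExec` ("for left and right outer joins, the output is partitioned by the streamed input's join keys"). For inner joins both children's clustering survives, for one-sided outer joins only the non-null-padded side's keys remain reliable, and a full outer join null-pads both sides, so the trait conservatively reports `UnknownPartitioning`.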

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 2 additions & 15 deletions
```
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean = false) extends BaseJoinExec with CodegenSupport {
+    isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -72,26 +72,13 @@ case class SortMergeJoinExec(
     }
   }

-  override def outputPartitioning: Partitioning = joinType match {
-    case _: InnerLike =>
-      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
-    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case LeftExistence(_) => left.outputPartitioning
-    case x =>
-      throw new IllegalArgumentException(
-        s"${getClass.getSimpleName} should not take $x as the JoinType")
-  }
-
   override def requiredChildDistribution: Seq[Distribution] = {
     if (isSkewJoin) {
       // We re-arrange the shuffle partitions to deal with skew join, and the new children
       // partitioning doesn't satisfy `HashClusteredDistribution`.
       UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
     } else {
-      HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+      super.requiredChildDistribution
     }
   }
```
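Net effect for `SortMergeJoinExec`: pure deduplication. The join-type dispatch moves into `ShuffledJoin`, and only the skew-join special case stays local, since adaptive execution has already re-arranged the shuffle partitions there and the children can no longer promise `HashClusteredDistribution`.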

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 18 additions & 0 deletions
```
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrd
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
@@ -1086,4 +1087,21 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       assert(df2.join(df1, "id").collect().isEmpty)
     }
   }
+
+  test("SPARK-32330: Preserve shuffled hash join build side partitioning") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+      val df1 = spark.range(10).select($"id".as("k1"))
+      val df2 = spark.range(30).select($"id".as("k2"))
+      Seq("inner", "cross").foreach(joinType => {
+        val plan = df1.join(df2, $"k1" === $"k2", joinType).groupBy($"k1").count()
+          .queryExecution.executedPlan
+        assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
+        // No extra shuffle before aggregate
+        assert(plan.collect { case _: ShuffleExchangeExec => true }.size === 2)
+      })
+    }
+  }
 }
```
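To exercise just this test locally, the usual ScalaTest substring filter should work (assuming a standard Spark checkout): `build/sbt "sql/testOnly *JoinSuite -- -z SPARK-32330"`, where `-z` selects tests whose names contain the given string.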
