Commit dface2a

Preserve shuffled hash join build side partitioning
1 parent db47c6e commit dface2a

3 files changed, +31 -1 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ trait HashJoin extends BaseJoinExec {
         existenceJoin(streamedIter, hashed)
       case x =>
         throw new IllegalArgumentException(
-          s"BroadcastHashJoin should not take $x as the JoinType")
+          s"HashJoin should not take $x as the JoinType")
     }
 
     val resultProj = createResultProjection
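A brief note on this rename: the HashJoin trait is shared by both the broadcast and shuffled hash join operators, so the error text should no longer single out BroadcastHashJoin when an unsupported JoinType reaches this branch.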

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 12 additions & 0 deletions
@@ -47,6 +47,18 @@ case class ShuffledHashJoinExec(
     "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
 
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(
+        s"${getClass.getSimpleName} should not take $x as the JoinType")
+  }
+
   override def requiredChildDistribution: Seq[Distribution] =
     HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
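With this override, a downstream operator that needs the data clustered on the join keys (such as the aggregation in the test below) can reuse the hash partitioning already produced for the join's children instead of triggering another ShuffleExchangeExec. A minimal spark-shell style sketch of the effect, mirroring the new test's configuration; the config values, row counts, and column names here are illustrative, not part of the commit:

// Illustrative spark-shell sketch (not part of this commit): these thresholds
// make the planner pick ShuffledHashJoinExec rather than broadcast or sort merge join.
import spark.implicits._

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50")
spark.conf.set("spark.sql.shuffle.partitions", "2")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(30).select($"id".as("k2"))

// With outputPartitioning preserved, the plan shows only the two exchanges that
// feed the join and no additional ShuffleExchangeExec in front of the aggregate.
df1.join(df2, $"k1" === $"k2").groupBy($"k1").count().explain()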

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrd
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
@@ -1086,4 +1087,21 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       assert(df2.join(df1, "id").collect().isEmpty)
     }
   }
+
+  test("SPARK-32330: Preserve shuffled hash join build side partitioning") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+      val df1 = spark.range(10).select($"id".as("k1"))
+      val df2 = spark.range(30).select($"id".as("k2"))
+      Seq("inner", "cross").foreach(joinType => {
+        val plan = df1.join(df2, $"k1" === $"k2", joinType).groupBy($"k1").count()
+          .queryExecution.executedPlan
+        assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
+        // No extra shuffle before aggregate
+        assert(plan.collect { case _: ShuffleExchangeExec => true }.size === 2)
+      })
+    }
+  }
 }
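For reference, the two ShuffleExchangeExec nodes the test expects are the exchanges that hash-partition each join input to satisfy HashClusteredDistribution; because the join now reports its children's partitioning, the groupBy on k1 no longer introduces a third.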
