Commit 9a4f8c3

wakun authored and GitHub Enterprise committed
[CARMEL-6174][FOLLOWUP] Change prefer shuffled hash join condition (#1099)
* [CARMEL-6174][FOLLOWUP] Change prefer shuffled hash join condition
* Select SHJ when the max partition size is larger than ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE
1 parent 41e5524 commit 9a4f8c3

3 files changed: +106 -40 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions

@@ -736,6 +736,14 @@ object SQLConf {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(0L)
 
+  val ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE =
+    buildConf("spark.sql.adaptive.shuffledHashJoinAdvisoryStreamPartitionSize")
+      .doc("If the median partition size is larger than this config, join selection prefers " +
+        "to use shuffled hash join.")
+      .version("3.2.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(0L)
+
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
       .internal()
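
For context (not part of the commit): both knobs default to 0, and the selection logic below bails out when either is non-positive, so a session has to set both for the new SHJ preference to apply. A minimal sketch, with illustrative byte values:

    // Hypothetical session setup; values are examples, not recommendations.
    // Both must be > 0, otherwise preferShuffledHashJoin() returns false early.
    spark.conf.set("spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold", "128m")
    spark.conf.set("spark.sql.adaptive.shuffledHashJoinAdvisoryStreamPartitionSize", "64m")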

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala

Lines changed: 51 additions & 35 deletions

@@ -23,26 +23,37 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Join, JoinStrategyHint, LogicalPlan, NO_BROADCAST_HASH, PREFER_SHUFFLE_HASH, SHUFFLE_HASH}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.execution.{CoalescedPartitionSpec, SparkPlan}
+import org.apache.spark.sql.execution.CoalescedPartitionSpec
 import org.apache.spark.sql.execution.adaptive.OptimizeSkewedJoin.getSkewThreshold
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
 
 /**
  * This optimization rule includes four join selections:
- * 1. detects a join child that has a high ratio of empty partitions and adds a
+ * 1. Do not add any hint until all the children are materialized and need no additional shuffle.
+ *    1.1 Won't select any join strategy for the following query, as it needs an additional
+ *        shuffle after all ShuffleQueryStageExec are materialized:
+ *        SortMergeJoin
+ *        :- ShuffleQueryStageExec (hashpartitioning(ID#1, 10000))
+ *        +- SortMergeJoin
+ *           :- ShuffleQueryStageExec (hashpartitioning(ID#2, 500))
+ *           +- ShuffleQueryStageExec (hashpartitioning(ID#3, 500))
+ *
+ *    1.2 Won't select any join strategy if the other side contains a bucket table:
+ *        SortMergeJoin
+ *        :- ShuffleQueryStageExec (hashpartitioning(ID#1, 10000))
+ *        +- BucketTableScan
+ *
+ * 2. detects a join child that has a high ratio of empty partitions and adds a
 *    NO_BROADCAST_HASH hint to avoid it being broadcast, as shuffle join is faster in this case:
 *    many tasks complete immediately since one join side is empty.
- * 2. detects a join child that every partition size is less than local map threshold and adds a
+ * 3. detects a join child whose partitions are all smaller than the local map threshold and adds a
 *    PREFER_SHUFFLE_HASH hint to encourage using shuffled hash join instead of sort merge join.
- * 3. if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
+ * 4. if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
 *    then add a SHUFFLE_HASH hint.
 */
 object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
 
-  val USER_DEFINED_HINT_TAG = TreeNodeTag[Boolean]("USER_DEFINED_HINT")
-
   private def hasManyEmptyPartitions(mapStats: MapOutputStatistics): Boolean = {
     val partitionCnt = mapStats.bytesByPartitionId.length
     val nonZeroCnt = mapStats.bytesByPartitionId.count(_ > 0)

@@ -65,22 +76,20 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
       streamedStats: Seq[MapOutputStatistics]): Boolean = {
     val maxShuffledHashJoinLocalMapThreshold =
       conf.getConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD)
+    val advisoryStreamPartitionSize =
+      conf.getConf(SQLConf.ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE)
     // If the join is skewed, since CARMEL will not handle SHJ skew join, and we are not sure SHJ
     // will be faster than SMJ for the left skew join pattern, do not convert to SHJ if any
     // join side is skewed.
-    if (maxShuffledHashJoinLocalMapThreshold <= 0 || streamedStats.exists(isSkew(_))) {
+    if (maxShuffledHashJoinLocalMapThreshold <= 0 || advisoryStreamPartitionSize <= 0 ||
+        streamedStats.exists(isSkew(_))) {
       return false
     }
-    val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
-      Array(mapStats) ++ streamedStats,
-      advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
-      minNumPartitions = 0)
-    partitionSpecs.nonEmpty &&
-      partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec]) &&
-      partitionSpecs.collect {
-        case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
-          mapStats.bytesByPartitionId.slice(startReducerIndex, endReducerIndex).sum
-      }.forall(_ <= maxShuffledHashJoinLocalMapThreshold)
+
+    mapStats.bytesByPartitionId.forall(_ <= maxShuffledHashJoinLocalMapThreshold) &&
+      streamedStats.filter(_.bytesByPartitionId.length > 0).exists { stats =>
+        Utils.median(stats.bytesByPartitionId, false) > advisoryStreamPartitionSize
+      }
   }
 
   private def selectJoinStrategy(
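
Read together, the rewritten predicate drops the coalesce-based estimate and compares raw map output sizes directly. A standalone restatement of the condition (a sketch with hypothetical helper names; Spark's Utils.median is approximated by a simple middle element, and the skew check is elided):

    // SHJ is preferred when every build-side partition fits under the local map
    // threshold AND some non-empty stream side has a median partition size larger
    // than the advisory stream partition size.
    def preferSHJSketch(
        buildSizes: Array[Long],        // mapStats.bytesByPartitionId
        streamSides: Seq[Array[Long]],  // streamed sides' bytesByPartitionId
        localMapThreshold: Long,
        advisoryStreamSize: Long): Boolean = {
      def median(xs: Array[Long]): Long = xs.sorted.apply(xs.length / 2)
      localMapThreshold > 0 && advisoryStreamSize > 0 &&
        buildSizes.forall(_ <= localMapThreshold) &&
        streamSides.filter(_.nonEmpty).exists(xs => median(xs) > advisoryStreamSize)
    }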
@@ -123,16 +132,29 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
       stage.computeStats().exists(_.rowCount.exists(_.toLong >= conf.broadcastMaxRowNum))
     val adjustDemoteBroadcastHash = rowNumberExceeded || demoteBroadcastHash
 
-    def collectShuffleStats(plan: LogicalPlan): Seq[MapOutputStatistics] = plan match {
+    var bucketedPlan = false
+    def collectShuffleStats(plan: LogicalPlan): Seq[Option[MapOutputStatistics]] = plan match {
       case LogicalQueryStage(_, streamedStage: ShuffleQueryStageExec)
           if streamedStage.isMaterialized && streamedStage.mapStats.isDefined =>
-        Seq(streamedStage.mapStats.get)
-      case _ => plan.children.flatMap(collectShuffleStats)
+        Seq(streamedStage.mapStats)
+      case LogicalQueryStage(_, _: ShuffleQueryStageExec) => Seq(None)
+      case _ if plan.children.nonEmpty => plan.children.flatMap(collectShuffleStats)
+      case _ =>
+        bucketedPlan = true
+        Seq()
     }
-    val preferShuffleHash =
-      preferShuffledHashJoin(stage.mapStats.get, collectShuffleStats(streamedPlan))
 
-    logInfo(s"canBroadcastPlan = $canBroadcastPlan, rowNumberExceeded = " +
+    val streamedStats = collectShuffleStats(streamedPlan)
+    val allStats = Array(stage.mapStats) ++ streamedStats
+
+    val shuffleMaterialized =
+      allStats.forall(_.isDefined) &&
+        allStats.map(_.get.bytesByPartitionId.length).distinct.length == 1
+    val preferShuffleHash = !bucketedPlan && shuffleMaterialized &&
+      preferShuffledHashJoin(stage.mapStats.get, streamedStats.map(_.get))
+
+    logInfo(s"isLeft = $isLeft, shuffleMaterialized = $shuffleMaterialized, " +
+      s"canBroadcastPlan = $canBroadcastPlan, rowNumberExceeded = " +
       s"$rowNumberExceeded, adjustDemoteBroadcastHash = $adjustDemoteBroadcastHash, " +
       s"preferShuffleHash = $preferShuffleHash")
     if (adjustDemoteBroadcastHash && preferShuffleHash) {
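
The shuffleMaterialized guard above encodes "no extra shuffle needed": every collected stage must have map stats (i.e. be materialized), and all stages must report the same reducer partition count. A minimal restatement (a sketch; Array[Long] stands in for MapOutputStatistics.bytesByPartitionId):

    // All shuffle stages materialized and their partition counts agree.
    def shuffleMaterializedSketch(allStats: Seq[Option[Array[Long]]]): Boolean =
      allStats.forall(_.isDefined) &&
        allStats.map(_.get.length).distinct.length == 1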
@@ -150,24 +172,18 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
-    case j @ ExtractEquiJoinKeys(_, _, _, _, left, right, hint) =>
-      if (left.getTagValue(USER_DEFINED_HINT_TAG).isEmpty) {
-        left.setTagValue(USER_DEFINED_HINT_TAG, hint.leftHint.exists(_.strategy.isDefined))
-      }
-      if (right.getTagValue(USER_DEFINED_HINT_TAG).isEmpty) {
-        right.setTagValue(USER_DEFINED_HINT_TAG, hint.rightHint.exists(_.strategy.isDefined))
-      }
+    case j @ ExtractEquiJoinKeys(_, _, _, _, _, _, hint) =>
       var newHint = hint
-      if (!left.getTagValue(USER_DEFINED_HINT_TAG).getOrElse(false)) {
+      if (!hint.leftHint.exists(_.strategy.isDefined) ||
+          hint.leftHint.get.strategy.contains(NO_BROADCAST_HASH)) {
         selectJoinStrategy(j, true).foreach { strategy =>
-          logInfo(s"Set left side join strategy: $strategy")
           newHint = newHint.copy(leftHint =
             Some(hint.leftHint.getOrElse(HintInfo()).copy(strategy = Some(strategy))))
         }
       }
-      if (!right.getTagValue(USER_DEFINED_HINT_TAG).getOrElse(false)) {
+      if (!hint.rightHint.exists(_.strategy.isDefined) ||
+          hint.rightHint.get.strategy.contains(NO_BROADCAST_HASH)) {
         selectJoinStrategy(j, false).foreach { strategy =>
-          logInfo(s"Set right side join strategy: $strategy")
           newHint = newHint.copy(rightHint =
             Some(hint.rightHint.getOrElse(HintInfo()).copy(strategy = Some(strategy))))
         }
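
The tree-tag bookkeeping is replaced by a direct check of the join hint itself: a side may receive a dynamically selected strategy only when the user set no strategy at all, or set exactly NO_BROADCAST_HASH (which the rule may still strengthen to SHUFFLE_HASH). As a plain predicate (a sketch, reusing the HintInfo type imported above):

    def canOverrideUserHint(userHint: Option[HintInfo]): Boolean =
      !userHint.exists(_.strategy.isDefined) ||
        userHint.get.strategy.contains(NO_BROADCAST_HASH)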

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 47 additions & 5 deletions

@@ -2496,22 +2496,64 @@ class AdaptiveQueryExecSuite
         (1 to 30).map(i => TestData(i, i.toString)), 5)
         .toDF("c1", "c2").createOrReplaceTempView("t2")
 
-      // left partition size: [926, 729, 731] after coalesce : [926, 1460]
-      // right partition size: [416, 258, 252] after coalesce : [416, 510]
+      // left partition size: [926, 729, 731]
+      // right partition size: [416, 258, 252]
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
         SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "450",
-        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2000",
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "300",
+        SQLConf.ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE.key -> "700",
         SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
         // check default value ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD = 0
         checkJoinStrategy(false)
-        withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "500") {
+        withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "300") {
           checkJoinStrategy(false)
         }
-        withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "800") {
+        withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "500") {
           checkJoinStrategy(true)
         }
       }
     }
   }
+
+  test("CARMEL-6174: Won't use SHJ for Bucket") {
+    withTempView("t1", "t2") {
+      def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
+        val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+          "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1")
+        assert(findTopLevelSortMergeJoin(origin1).size === 1)
+        if (shouldShuffleHashJoin) {
+          val shj = findTopLevelShuffledHashJoin(adaptive1)
+          assert(shj.size === 1)
+          assert(shj.head.buildSide == BuildRight)
+        } else {
+          assert(findTopLevelSortMergeJoin(adaptive1).size === 1)
+        }
+
+        withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "0") {
+          // respect user specified join hint
+          val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+            "SELECT /*+ MERGE(t1) */ t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1")
+          assert(findTopLevelSortMergeJoin(origin2).size === 1)
+          assert(findTopLevelSortMergeJoin(adaptive2).size === 1)
+        }
+      }
+
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+        spark.range(1, 100, 10).map(i => (i, i.toString)).toDF("c1", "c2")
+          .write.format("parquet").saveAsTable("t1")
+        spark.range(1, 30, 5).map(i => (i, i.toString)).toDF("c1", "c2")
+          .write.format("parquet").bucketBy(2, "c1").sortBy("c1").saveAsTable("t2")
+
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+          SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "1024000",
+          SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200",
+          SQLConf.ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE.key -> "500",
+          SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
+          checkJoinStrategy(false)
+        }
+      }
+    }
+  }
 }
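
Plugging the first test's partition sizes into the new predicate shows why the thresholds flip the strategy. A worked sketch (values copied from the test comments above; median here is the middle element of three, standing in for Utils.median):

    // build (right) side: [416, 258, 252]; stream (left) side: [926, 729, 731]
    val build = Array(416L, 258L, 252L)
    val stream = Array(926L, 729L, 731L)
    def median(xs: Array[Long]): Long = xs.sorted.apply(xs.length / 2)  // 731 for stream
    // threshold 300: 416 > 300, so SHJ is not preferred          -> SMJ kept
    // threshold 500: all of [416, 258, 252] <= 500 and 731 > 700 -> SHJ with BuildRight
    def preferSHJ(threshold: Long): Boolean =
      build.forall(_ <= threshold) && median(stream) > 700L
    assert(!preferSHJ(300L) && preferSHJ(500L))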
