@@ -23,26 +23,37 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
23
23
import org .apache .spark .sql .catalyst .plans .{LeftAnti , LeftOuter , RightOuter }
24
24
import org .apache .spark .sql .catalyst .plans .logical .{HintInfo , Join , JoinStrategyHint , LogicalPlan , NO_BROADCAST_HASH , PREFER_SHUFFLE_HASH , SHUFFLE_HASH }
25
25
import org .apache .spark .sql .catalyst .rules .Rule
26
- import org .apache .spark .sql .catalyst .trees .TreeNodeTag
27
- import org .apache .spark .sql .execution .{CoalescedPartitionSpec , SparkPlan }
26
+ import org .apache .spark .sql .execution .CoalescedPartitionSpec
28
27
import org .apache .spark .sql .execution .adaptive .OptimizeSkewedJoin .getSkewThreshold
29
28
import org .apache .spark .sql .internal .SQLConf
30
29
import org .apache .spark .util .Utils
31
30
32
31
/**
33
32
* This optimization rule includes three join selection:
34
- * 1. detects a join child that has a high ratio of empty partitions and adds a
33
+ * 1. Do not add any until all the children are materialized and don't need additional shuffle.
34
+ * 1.1 Won't select any join strategy for the following query as it need additional shuffle
35
+ * after all ShuffleQueryStageExec are materialized
36
+ * SortMergeJoin
37
+ * :- ShuffleQueryStageExec (hashpartitioning(ID#1, 10000))
38
+ * +- SortMergeJoin
39
+ * :- ShuffleQueryStageExec (hashpartitioning(ID#2, 500))
40
+ * +- ShuffleQueryStageExec (hashpartitioning(ID#3, 500))
41
+ *
42
+ * 1.2 Won't select any join strategy if the other side contains bucket table.
43
+ * SortMergeJoin
44
+ * :- ShuffleQueryStageExec (hashpartitioning(ID#1, 10000))
45
+ * +- BucketTableScan
46
+ *
47
+ * 2. detects a join child that has a high ratio of empty partitions and adds a
35
48
* NO_BROADCAST_HASH hint to avoid it being broadcast, as shuffle join is faster in this case:
36
49
* many tasks complete immediately since one join side is empty.
37
- * 2 . detects a join child that every partition size is less than local map threshold and adds a
50
+ * 3 . detects a join child that every partition size is less than local map threshold and adds a
38
51
* PREFER_SHUFFLE_HASH hint to encourage being shuffle hash join instead of sort merge join.
39
- * 3 . if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
52
+ * 4 . if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
40
53
* then add a SHUFFLE_HASH hint.
41
54
*/
42
55
object DynamicJoinSelection extends Rule [LogicalPlan ] with JoinSelectionHelper {
43
56
44
- val USER_DEFINED_HINT_TAG = TreeNodeTag [Boolean ](" USER_DEFINED_HINT" )
45
-
46
57
private def hasManyEmptyPartitions (mapStats : MapOutputStatistics ): Boolean = {
47
58
val partitionCnt = mapStats.bytesByPartitionId.length
48
59
val nonZeroCnt = mapStats.bytesByPartitionId.count(_ > 0 )
@@ -65,22 +76,20 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
65
76
streamedStats : Seq [MapOutputStatistics ]): Boolean = {
66
77
val maxShuffledHashJoinLocalMapThreshold =
67
78
conf.getConf(SQLConf .ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD )
79
+ val advisoryStreamPartitionSize =
80
+ conf.getConf(SQLConf .ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE )
68
81
// If the join is skew, since CARMEL will not handle SHJ skew join, and we are not sure SHJ
69
82
// will be faster better SMJ for the left skew join patten, so do not convert to SHJ if any
70
83
// join side is skew.
71
- if (maxShuffledHashJoinLocalMapThreshold <= 0 || streamedStats.exists(isSkew(_))) {
84
+ if (maxShuffledHashJoinLocalMapThreshold <= 0 || advisoryStreamPartitionSize <= 0 ||
85
+ streamedStats.exists(isSkew(_))) {
72
86
return false
73
87
}
74
- val partitionSpecs = ShufflePartitionsUtil .coalescePartitions(
75
- Array (mapStats) ++ streamedStats,
76
- advisoryTargetSize = conf.getConf(SQLConf .ADVISORY_PARTITION_SIZE_IN_BYTES ),
77
- minNumPartitions = 0 )
78
- partitionSpecs.nonEmpty &&
79
- partitionSpecs.forall(_.isInstanceOf [CoalescedPartitionSpec ]) &&
80
- partitionSpecs.collect {
81
- case CoalescedPartitionSpec (startReducerIndex, endReducerIndex) =>
82
- mapStats.bytesByPartitionId.slice(startReducerIndex, endReducerIndex).sum
83
- }.forall(_ <= maxShuffledHashJoinLocalMapThreshold)
88
+
89
+ mapStats.bytesByPartitionId.forall(_ <= maxShuffledHashJoinLocalMapThreshold) &&
90
+ streamedStats.filter(_.bytesByPartitionId.length > 0 ).exists { stats =>
91
+ Utils .median(stats.bytesByPartitionId, false ) > advisoryStreamPartitionSize
92
+ }
84
93
}
85
94
86
95
private def selectJoinStrategy (
@@ -123,16 +132,29 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
123
132
stage.computeStats().exists(_.rowCount.exists(_.toLong >= conf.broadcastMaxRowNum))
124
133
val adjustDemoteBroadcastHash = rowNumberExceeded || demoteBroadcastHash
125
134
126
- def collectShuffleStats (plan : LogicalPlan ): Seq [MapOutputStatistics ] = plan match {
135
+ var bucketedPlan = false
136
+ def collectShuffleStats (plan : LogicalPlan ): Seq [Option [MapOutputStatistics ]] = plan match {
127
137
case LogicalQueryStage (_, streamedStage : ShuffleQueryStageExec )
128
138
if streamedStage.isMaterialized && streamedStage.mapStats.isDefined =>
129
- Seq (streamedStage.mapStats.get)
130
- case _ => plan.children.flatMap(collectShuffleStats)
139
+ Seq (streamedStage.mapStats)
140
+ case LogicalQueryStage (_, _ : ShuffleQueryStageExec ) => Seq (None )
141
+ case _ if plan.children.nonEmpty => plan.children.flatMap(collectShuffleStats)
142
+ case _ =>
143
+ bucketedPlan = true
144
+ Seq ()
131
145
}
132
- val preferShuffleHash =
133
- preferShuffledHashJoin(stage.mapStats.get, collectShuffleStats(streamedPlan))
134
146
135
- logInfo(s " canBroadcastPlan = $canBroadcastPlan, rowNumberExceeded = " +
147
+ val streamedStats = collectShuffleStats(streamedPlan)
148
+ val allStats = Array (stage.mapStats) ++ streamedStats
149
+
150
+ val shuffleMaterialized =
151
+ allStats.forall(_.isDefined) &&
152
+ allStats.map(_.get.bytesByPartitionId.length).distinct.length == 1
153
+ val preferShuffleHash = ! bucketedPlan && shuffleMaterialized &&
154
+ preferShuffledHashJoin(stage.mapStats.get, streamedStats.map(_.get))
155
+
156
+ logInfo(s " isLeft = $isLeft, shuffleMaterialized = $shuffleMaterialized, " +
157
+ s " canBroadcastPlan = $canBroadcastPlan, rowNumberExceeded = " +
136
158
s " $rowNumberExceeded, adjustDemoteBroadcastHash = $adjustDemoteBroadcastHash, " +
137
159
s " preferShuffleHash = $preferShuffleHash" )
138
160
if (adjustDemoteBroadcastHash && preferShuffleHash) {
@@ -150,24 +172,18 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
150
172
}
151
173
152
174
def apply (plan : LogicalPlan ): LogicalPlan = plan.transformDown {
153
- case j @ ExtractEquiJoinKeys (_, _, _, _, left, right, hint) =>
154
- if (left.getTagValue(USER_DEFINED_HINT_TAG ).isEmpty) {
155
- left.setTagValue(USER_DEFINED_HINT_TAG , hint.leftHint.exists(_.strategy.isDefined))
156
- }
157
- if (right.getTagValue(USER_DEFINED_HINT_TAG ).isEmpty) {
158
- right.setTagValue(USER_DEFINED_HINT_TAG , hint.rightHint.exists(_.strategy.isDefined))
159
- }
175
+ case j @ ExtractEquiJoinKeys (_, _, _, _, _, _, hint) =>
160
176
var newHint = hint
161
- if (! left.getTagValue(USER_DEFINED_HINT_TAG ).getOrElse(false )) {
177
+ if (! hint.leftHint.exists(_.strategy.isDefined) ||
178
+ hint.leftHint.get.strategy.contains(NO_BROADCAST_HASH )) {
162
179
selectJoinStrategy(j, true ).foreach { strategy =>
163
- logInfo(s " Set left side join strategy: $strategy" )
164
180
newHint = newHint.copy(leftHint =
165
181
Some (hint.leftHint.getOrElse(HintInfo ()).copy(strategy = Some (strategy))))
166
182
}
167
183
}
168
- if (! right.getTagValue(USER_DEFINED_HINT_TAG ).getOrElse(false )) {
184
+ if (! hint.rightHint.exists(_.strategy.isDefined) ||
185
+ hint.rightHint.get.strategy.contains(NO_BROADCAST_HASH )) {
169
186
selectJoinStrategy(j, false ).foreach { strategy =>
170
- logInfo(s " Set right side join strategy: $strategy" )
171
187
newHint = newHint.copy(rightHint =
172
188
Some (hint.rightHint.getOrElse(HintInfo ()).copy(strategy = Some (strategy))))
173
189
}
0 commit comments