
Commit edbdea1

Jiajia Li authored and carsonwang committed
Disable change reduce number if the joins are changed (apache#81)
* Disable change reduce number if the joins are changed
* Change reduce number when all leaf nodes are shuffle query stages and not local shuffles
* Ensure all leaf nodes are shuffle query stages
* Update comments.
1 parent 61bd1c9 commit edbdea1
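
In short: before changing the number of reduce tasks for a query stage, the planner now collects the stage's leaf nodes and proceeds only when every leaf is a non-local ShuffleQueryStageInput or a SkewedShuffleQueryStageInput. A toy sketch of that guard follows; the types are illustrative stand-ins, not the real Spark plan nodes, and the authoritative change is the QueryStage.scala diff below.

// Toy model of the guard added by this commit (stand-in types, not Spark's).
object LeafGuardSketch extends App {
  sealed trait Plan { def children: Seq[Plan] }
  case class ShuffleStageInput(isLocalShuffle: Boolean) extends Plan { def children = Nil }
  case object SkewedStageInput extends Plan { def children = Nil }
  case class Join(left: Plan, right: Plan) extends Plan { def children = Seq(left, right) }

  def leaves(p: Plan): Seq[Plan] =
    if (p.children.isEmpty) Seq(p) else p.children.flatMap(leaves)

  // Only change the reduce number when every leaf is a non-local shuffle
  // query stage or a skewed shuffle query stage input.
  def canChangeReduceNumber(root: Plan): Boolean = {
    val ls = leaves(root)
    val shuffles = ls.count {
      case s: ShuffleStageInput => !s.isLocalShuffle
      case _ => false
    }
    val skewed = ls.count(_ == SkewedStageInput)
    ls.length == shuffles + skewed
  }

  // Per the commit message, a local shuffle leaf (which a changed join may
  // leave behind) blocks the change; an all-shuffle plan still allows it.
  assert(!canChangeReduceNumber(Join(ShuffleStageInput(isLocalShuffle = true), ShuffleStageInput(false))))
  assert(canChangeReduceNumber(Join(ShuffleStageInput(false), ShuffleStageInput(false))))
}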

File tree: 2 files changed, +105 -24 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala

Lines changed: 36 additions & 24 deletions
@@ -113,38 +113,50 @@ abstract class QueryStage extends UnaryExecNode {
     val queryStageInputs: Seq[ShuffleQueryStageInput] = child.collect {
       case input: ShuffleQueryStageInput if !input.isLocalShuffle => input
     }
-    val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics)
-      .filter(_ != null).toArray
-    // Right now, Adaptive execution only supports HashPartitionings.
-    val supportAdaptive = queryStageInputs.forall {
+
+    val skewedShuffleQueryStageInputs: Seq[SkewedShuffleQueryStageInput] = child.collect {
+      case input: SkewedShuffleQueryStageInput => input
+    }
+
+    val leafNodes = child.collect {
+      case s: SparkPlan if s.children.isEmpty => s
+    }
+
+    // Ensure all leaf nodes are shuffle query stages
+    if (leafNodes.length == queryStageInputs.length + skewedShuffleQueryStageInputs.length) {
+      val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics)
+        .filter(_ != null).toArray
+      // Right now, Adaptive execution only supports HashPartitionings.
+      val supportAdaptive = queryStageInputs.forall {
       _.outputPartitioning match {
         case hash: HashPartitioning => true
         case collection: PartitioningCollection =>
           collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
         case _ => false
       }
-    }
+      }
 
-    if (childMapOutputStatistics.length > 0 && supportAdaptive) {
-      val exchangeCoordinator = new ExchangeCoordinator(
-        conf.targetPostShuffleInputSize,
-        conf.adaptiveTargetPostShuffleRowCount,
-        conf.minNumPostShufflePartitions)
-
-      if (queryStageInputs.length == 2 && queryStageInputs.forall(_.skewedPartitions.isDefined)) {
-        // If a skewed join is detected and optimized, we will omit the skewed partitions when
-        // estimating the partition start and end indices.
-        val (partitionStartIndices, partitionEndIndices) =
-          exchangeCoordinator.estimatePartitionStartEndIndices(
-            childMapOutputStatistics, queryStageInputs(0).skewedPartitions.get)
-        queryStageInputs.foreach { i =>
-          i.partitionStartIndices = Some(partitionStartIndices)
-          i.partitionEndIndices = Some(partitionEndIndices)
+      if (childMapOutputStatistics.length > 0 && supportAdaptive) {
+        val exchangeCoordinator = new ExchangeCoordinator(
+          conf.targetPostShuffleInputSize,
+          conf.adaptiveTargetPostShuffleRowCount,
+          conf.minNumPostShufflePartitions)
+
+        if (queryStageInputs.length == 2 && queryStageInputs.forall(_.skewedPartitions.isDefined)) {
+          // If a skewed join is detected and optimized, we will omit the skewed partitions when
+          // estimating the partition start and end indices.
+          val (partitionStartIndices, partitionEndIndices) =
+            exchangeCoordinator.estimatePartitionStartEndIndices(
+              childMapOutputStatistics, queryStageInputs(0).skewedPartitions.get)
+          queryStageInputs.foreach { i =>
+            i.partitionStartIndices = Some(partitionStartIndices)
+            i.partitionEndIndices = Some(partitionEndIndices)
+          }
+        } else {
+          val partitionStartIndices =
+            exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics)
+          queryStageInputs.foreach(_.partitionStartIndices = Some(partitionStartIndices))
         }
-      } else {
-        val partitionStartIndices =
-          exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics)
-        queryStageInputs.foreach(_.partitionStartIndices = Some(partitionStartIndices))
       }
     }
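
For context on estimatePartitionStartIndices used above: a minimal sketch of the coalescing idea, modeled on upstream Spark 2.x's ExchangeCoordinator (an assumption on my part; this fork's coordinator also takes a row-count target, adaptiveTargetPostShuffleRowCount, which the sketch omits). Sizes of the same map-output partition are summed across all child shuffles, and a new post-shuffle partition starts whenever the running total would exceed the byte target.

// Simplified model of post-shuffle partition coalescing (illustrative only;
// modeled on upstream Spark 2.x, not this fork's exact implementation).
object CoalesceSketch extends App {
  def estimatePartitionStartIndices(
      bytesByShuffle: Array[Array[Long]], // per shuffle: bytes per map-output partition
      targetSize: Long): Array[Int] = {
    val numPartitions = bytesByShuffle.head.length
    val starts = scala.collection.mutable.ArrayBuffer(0)
    var running = 0L
    var i = 0
    while (i < numPartitions) {
      // The same partition id across all shuffles is read by one reducer.
      val size = bytesByShuffle.map(_(i)).sum
      if (i > 0 && running + size > targetSize) {
        starts += i // close the current post-shuffle partition, open a new one
        running = size
      } else {
        running += size
      }
      i += 1
    }
    starts.toArray
  }

  // Two shuffles, five partitions of 50 combined bytes each, 100-byte target.
  val stats = Array(Array(30L, 30L, 30L, 30L, 30L), Array(20L, 20L, 20L, 20L, 20L))
  println(estimatePartitionStartIndices(stats, targetSize = 100L).mkString(", ")) // 0, 2, 4
}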

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala

Lines changed: 69 additions & 0 deletions
@@ -262,6 +262,75 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("One of two sort merge inner joins to broadcast join") {
+    // t1 is smaller than spark.sql.adaptiveBroadcastJoinThreshold
+    // t2 and t3 are greater than spark.sql.adaptiveBroadcastJoinThreshold
+    // Join1 is changed to broadcast join.
+    //
+    //        Join2
+    //        /   \
+    //    Join1    Ex (Exchange)
+    //    /   \     \
+    //   Ex    Ex    t3
+    //   /      \
+    //  t1      t2
+    val spark = defaultSparkSession
+    spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ALLOW_ADDITIONAL_SHUFFLE.key, "true")
+    withSparkSession(spark) { spark: SparkSession =>
+      val df1 =
+        spark
+          .range(0, 1000, 1, numInputPartitions)
+          .selectExpr("id % 500 as key1", "id as value1")
+      val df2 =
+        spark
+          .range(0, 1000, 1, numInputPartitions)
+          .selectExpr("id % 500 as key2", "id as value2")
+      val df3 =
+        spark
+          .range(0, 1500, 1, numInputPartitions)
+          .selectExpr("id % 500 as key3", "id as value3")
+
+      val join =
+        df1
+          .join(df2, col("key1") === col("key2"))
+          .join(df3, col("key2") === col("key3"))
+          .select(col("key3"), col("value1"))
+
+      // Before execution, there are two SortMergeJoins
+      val smjBeforeExecution = join.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }
+      assert(smjBeforeExecution.length === 2)
+
+      // Check the answer.
+      val partResult =
+        spark
+          .range(0, 1000)
+          .selectExpr("id % 500 as key", "id as value")
+          .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+      val expectedAnswer = partResult.union(partResult).union(partResult)
+      checkAnswer(
+        join,
+        expectedAnswer.collect())
+
+      // During execution, one SortMergeJoin is changed to BroadcastHashJoin
+      val numSmjAfterExecution = join.queryExecution.executedPlan.collect {
+        case smj: SortMergeJoinExec => smj
+      }.length
+      assert(numSmjAfterExecution === 1)
+
+      val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
+        case bhj: BroadcastHashJoinExec => bhj
+      }.length
+      assert(numBhjAfterExecution === 1)
+
+      val queryStageInputs = join.queryExecution.executedPlan.collect {
+        case q: QueryStageInput => q
+      }
+      assert(queryStageInputs.length === 3)
+    }
+  }
+
   test("Reuse QueryStage in adaptive execution") {
     withSparkSession(defaultSparkSession) { spark: SparkSession =>
       val df = spark.range(0, 1000, 1, numInputPartitions).toDF()

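A quick check of the arithmetic behind expectedAnswer in the new test (my reasoning, not part of the commit): each key appears twice in df1 and df2 and three times in df3, so the join emits 2 * 2 * 3 = 12 rows per key and 500 * 12 = 6000 rows in total; partResult has 2000 rows, so three copies of it match, each (key, value) pair occurring six times on both sides.

// Row-count arithmetic for the test's expected answer (illustrative only).
object ExpectedAnswerCheck extends App {
  val keys = 500
  val perKey1 = 1000 / keys             // each key appears twice in df1
  val perKey2 = 1000 / keys             // and twice in df2
  val perKey3 = 1500 / keys             // and three times in df3
  val joinRows = keys * perKey1 * perKey2 * perKey3 // 500 * 12 = 6000
  val partResultRows = 2 * 1000         // union of two 1000-row frames
  assert(joinRows == 3 * partResultRows) // expectedAnswer = 3 copies of partResult
  println(s"$joinRows rows from the join == 3 x $partResultRows expected")
}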