
Commit dfa3978

maryannxue authored and gatorsmile committed
[SPARK-33551][SQL] Do not use custom shuffle reader for repartition
### What changes were proposed in this pull request?

This PR fixes an AQE issue where local shuffle reader, partition coalescing, or skew join optimization can be mistakenly applied to a shuffle introduced by repartition, or to a regular shuffle that logically replaces a repartition shuffle. The proposed solution checks for the presence of any repartition shuffle and filters out the optimization rules that are not applicable to the final stage of an AQE plan.

### Why are the changes needed?

Without the change, the output of a repartition query may not be correct.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added UT.

Closes #30494 from maryannxue/csr-repartition.

Authored-by: Maryann Xue <maryann.xue@gmail.com>
Signed-off-by: Xiao Li <gatorsmile@gmail.com>
1 parent ed9e6fc commit dfa3978
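
For context, a minimal sketch (not from the patch; assumes a live `SparkSession` named `spark` with AQE enabled) of the invariant the fix protects: a partition count requested via `repartition` must survive AQE's final-stage optimizations.

```scala
// Hypothetical illustration, not part of this commit: AQE must not coalesce
// away a user-requested partition count.
import org.apache.spark.sql.functions.col

val df = spark.range(100).repartition(3, col("id") % 10)
df.collect()
// With the fix, the final-stage reader rules are skipped for this
// REPARTITION_WITH_NUM shuffle, so the requested count of 3 is preserved.
assert(df.rdd.getNumPartitions == 3)
```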

File tree

7 files changed (+187, -29)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -509,7 +509,7 @@ object SQLConf {
         "'spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes'")
       .version("3.0.0")
       .intConf
-      .checkValue(_ > 0, "The skew factor must be positive.")
+      .checkValue(_ >= 0, "The skew factor cannot be negative.")
       .createWithDefault(5)

   val SKEW_JOIN_SKEWED_PARTITION_THRESHOLD =
```
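
The relaxed bound is what lets the new test force skew detection on tiny data. A hedged usage sketch (assumes a live `SparkSession` named `spark`; the config keys match the ones referenced in the test below):

```scala
// With the relaxed check, a skew factor of 0 is now a valid setting, which the
// new UT uses to make every partition qualify as skewed.
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "0")
```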

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 18 additions & 13 deletions
```diff
@@ -37,8 +37,6 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
 import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
-import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -104,23 +102,30 @@ case class AdaptiveSparkPlanExec(
     OptimizeLocalShuffleReader
   )

-  private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] =
-    context.qe.sparkPlan match {
-      case _: DataWritingCommandExec | _: V2TableWriteExec =>
-        // SPARK-32932: Local shuffle reader could break partitioning that works best
-        // for the following writing command
-        queryStageOptimizerRules.filterNot(_ == OptimizeLocalShuffleReader)
-      case _ =>
-        queryStageOptimizerRules
-    }
-
   // A list of physical optimizer rules to be applied right after a new stage is created. The input
   // plan to these rules has exchange as its root node.
   @transient private val postStageCreationRules = Seq(
     ApplyColumnarRulesAndInsertTransitions(context.session.sessionState.columnarRules),
     CollapseCodegenStages()
   )

+  // The partitioning of the query output depends on the shuffle(s) in the final stage. If the
+  // original plan contains a repartition operator, we need to preserve the specified partitioning,
+  // whether or not the repartition-introduced shuffle is optimized out because of an underlying
+  // shuffle of the same partitioning. Thus, we need to exclude some `CustomShuffleReaderRule`s
+  // from the final stage, depending on the presence and properties of repartition operators.
+  private def finalStageOptimizerRules: Seq[Rule[SparkPlan]] = {
+    val origins = inputPlan.collect {
+      case s: ShuffleExchangeLike => s.shuffleOrigin
+    }
+    val allRules = queryStageOptimizerRules ++ postStageCreationRules
+    allRules.filter {
+      case c: CustomShuffleReaderRule =>
+        origins.forall(c.supportedShuffleOrigins.contains)
+      case _ => true
+    }
+  }
+
   @transient private val costEvaluator = SimpleCostEvaluator

   @transient private val initialPlan = context.session.withActive {
@@ -249,7 +254,7 @@ case class AdaptiveSparkPlanExec(
         // Run the final plan when there's no more unfinished stages.
         currentPhysicalPlan = applyPhysicalRules(
           result.newPlan,
-          finalStageOptimizerRules ++ postStageCreationRules,
+          finalStageOptimizerRules,
           Some((planChangeLogger, "AQE Final Query Stage Optimization")))
         isFinalPlan = true
         executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
```
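
To see what the new `finalStageOptimizerRules` computes, here is a self-contained sketch of the same origin-based filtering using stand-in types (the real `ShuffleOrigin`, rule, and plan classes are Spark internals; everything below is illustrative only):

```scala
// Stand-in types, not the real Spark classes.
sealed trait ShuffleOrigin
case object ENSURE_REQUIREMENTS extends ShuffleOrigin
case object REPARTITION extends ShuffleOrigin
case object REPARTITION_WITH_NUM extends ShuffleOrigin

sealed trait PlanRule { def name: String }
// Stand-in for CustomShuffleReaderRule: valid only for the origins it declares.
case class ReaderRule(name: String, supportedShuffleOrigins: Seq[ShuffleOrigin]) extends PlanRule
case class OtherRule(name: String) extends PlanRule

// Keep a reader rule only if it supports every shuffle origin present in the
// plan; all other rules always stay. This mirrors the filter above.
def finalStageRules(origins: Seq[ShuffleOrigin], rules: Seq[PlanRule]): Seq[PlanRule] =
  rules.filter {
    case r: ReaderRule => origins.forall(r.supportedShuffleOrigins.contains)
    case _             => true
  }

val rules = Seq(
  ReaderRule("OptimizeLocalShuffleReader", Seq(ENSURE_REQUIREMENTS)),
  ReaderRule("CoalesceShufflePartitions", Seq(ENSURE_REQUIREMENTS, REPARTITION)),
  OtherRule("CollapseCodegenStages"))

// A plan containing a REPARTITION_WITH_NUM shuffle keeps none of the reader
// rules, so the user-specified partitioning cannot be altered in the final stage.
assert(finalStageRules(Seq(ENSURE_REQUIREMENTS, REPARTITION_WITH_NUM), rules).map(_.name) ==
  Seq("CollapseCodegenStages"))
```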

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala

Lines changed: 6 additions & 5 deletions
```diff
@@ -19,16 +19,18 @@ package org.apache.spark.sql.execution.adaptive

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.internal.SQLConf

 /**
  * A rule to coalesce the shuffle partitions based on the map output statistics, which can
  * avoid many small reduce tasks that hurt performance.
  */
-case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPlan] {
+case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffleReaderRule {
+
+  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS, REPARTITION)

   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.coalesceShufflePartitionsEnabled) {
       return plan
@@ -86,7 +88,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
   }

   private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
-    s.outputPartitioning != SinglePartition &&
-      (s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
+    s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
   }
 }
```
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderRule.scala

Lines changed: 33 additions & 0 deletions

```diff
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.ShuffleOrigin
+
+/**
+ * Adaptive Query Execution rule that may create [[CustomShuffleReaderExec]] on top of query stages.
+ */
+trait CustomShuffleReaderRule extends Rule[SparkPlan] {
+
+  /**
+   * Returns the list of [[ShuffleOrigin]]s supported by this rule.
+   */
+  def supportedShuffleOrigins: Seq[ShuffleOrigin]
+}
```
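
For illustration only, a hypothetical rule implementing the new trait might look like this (the object name and no-op behavior are invented for the example; the concrete implementors in this commit are `CoalesceShufflePartitions`, `OptimizeLocalShuffleReader`, and `OptimizeSkewedJoin`):

```scala
// Hypothetical example rule, not part of this commit: declares that it only
// understands shuffles inserted by EnsureRequirements, and rewrites nothing.
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}

object NoOpShuffleReaderRule extends CustomShuffleReaderRule {
  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
  override def apply(plan: SparkPlan): SparkPlan = plan
}
```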

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala

Lines changed: 5 additions & 4 deletions
```diff
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution.adaptive

 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.internal.SQLConf

@@ -34,7 +33,9 @@ import org.apache.spark.sql.internal.SQLConf
  * then run `EnsureRequirements` to check whether additional shuffle introduced.
  * If introduced, we will revert all the local readers.
  */
-object OptimizeLocalShuffleReader extends Rule[SparkPlan] {
+object OptimizeLocalShuffleReader extends CustomShuffleReaderRule {
+
+  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)

   private val ensureRequirements = EnsureRequirements

@@ -144,6 +145,6 @@ object OptimizeLocalShuffleReader extends Rule[SparkPlan] {
   }

   private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
-    s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
+    s.outputPartitioning != SinglePartition && supportedShuffleOrigins.contains(s.shuffleOrigin)
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 9 additions & 5 deletions
```diff
@@ -23,9 +23,8 @@ import org.apache.commons.io.FileUtils

 import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf

@@ -53,7 +52,9 @@ import org.apache.spark.sql.internal.SQLConf
  * Note that, when this rule is enabled, it also coalesces non-skewed partitions like
  * `CoalesceShufflePartitions` does.
  */
-object OptimizeSkewedJoin extends Rule[SparkPlan] {
+object OptimizeSkewedJoin extends CustomShuffleReaderRule {
+
+  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)

   private val ensureRequirements = EnsureRequirements

@@ -290,7 +291,9 @@ object OptimizeSkewedJoin extends Rule[SparkPlan] {

 private object ShuffleStage {
   def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
-    case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
+    case s: ShuffleQueryStageExec
+        if s.mapStats.isDefined &&
+          OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) =>
       val mapStats = s.mapStats.get
       val sizes = mapStats.bytesByPartitionId
       val partitions = sizes.zipWithIndex.map {
@@ -299,7 +302,8 @@ private object ShuffleStage {
       Some(ShuffleStageInfo(s, mapStats, partitions))

     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs)
-      if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
+      if s.mapStats.isDefined && partitionSpecs.nonEmpty &&
+        OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) =>
       val mapStats = s.mapStats.get
       val sizes = mapStats.bytesByPartitionId
       val partitions = partitionSpecs.map {
```

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 115 additions & 1 deletion
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecuti
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
@@ -1317,4 +1317,118 @@ class AdaptiveQueryExecSuite
       checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
     }
   }
+
+  test("SPARK-33551: Do not use custom shuffle reader for repartition") {
+    def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
+      find(plan) {
+        case s: ShuffleExchangeLike =>
+          s.shuffleOrigin == REPARTITION || s.shuffleOrigin == REPARTITION_WITH_NUM
+        case _ => false
+      }.isDefined
+    }
+
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      val df = sql(
+        """
+          |SELECT * FROM (
+          |  SELECT * FROM testData WHERE key = 1
+          |)
+          |RIGHT OUTER JOIN testData2
+          |ON value = b
+        """.stripMargin)
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+        // Repartition with no partition num specified.
+        val dfRepartition = df.repartition('b)
+        dfRepartition.collect()
+        val plan = dfRepartition.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(plan))
+        val bhj = findTopLevelBroadcastHashJoin(plan)
+        assert(bhj.length == 1)
+        checkNumLocalShuffleReaders(plan, 1)
+        // Probe side is coalesced.
+        val customReader = bhj.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
+        assert(customReader.isDefined)
+        assert(customReader.get.asInstanceOf[CustomShuffleReaderExec].hasCoalescedPartition)
+
+        // Repartition with default partition num specified.
+        val dfRepartitionWithNum = df.repartition(5, 'b)
+        dfRepartitionWithNum.collect()
+        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(planWithNum))
+        val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
+        assert(bhjWithNum.length == 1)
+        checkNumLocalShuffleReaders(planWithNum, 1)
+        // Probe side is not coalesced.
+        assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
+
+        // Repartition with non-default partition num specified.
+        val dfRepartitionWithNum2 = df.repartition(3, 'b)
+        dfRepartitionWithNum2.collect()
+        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
+        // The top shuffle from repartition is not optimized out, and this is the only shuffle
+        // that does not have local shuffle reader.
+        assert(hasRepartitionShuffle(planWithNum2))
+        val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
+        assert(bhjWithNum2.length == 1)
+        checkNumLocalShuffleReaders(planWithNum2, 1)
+        val customReader2 = bhjWithNum2.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
+        assert(customReader2.isDefined)
+        assert(customReader2.get.asInstanceOf[CustomShuffleReaderExec].isLocalReader)
+      }
+
+      // Force skew join
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.SKEW_JOIN_ENABLED.key -> "true",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
+        // Repartition with no partition num specified.
+        val dfRepartition = df.repartition('b)
+        dfRepartition.collect()
+        val plan = dfRepartition.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(plan))
+        val smj = findTopLevelSortMergeJoin(plan)
+        assert(smj.length == 1)
+        // No skew join due to the repartition.
+        assert(!smj.head.isSkewJoin)
+        // Both sides are coalesced.
+        val customReaders = collect(smj.head) {
+          case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+        }
+        assert(customReaders.length == 2)
+
+        // Repartition with default partition num specified.
+        val dfRepartitionWithNum = df.repartition(5, 'b)
+        dfRepartitionWithNum.collect()
+        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(planWithNum))
+        val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
+        assert(smjWithNum.length == 1)
+        // No skew join due to the repartition.
+        assert(!smjWithNum.head.isSkewJoin)
+        // No coalesce due to the num in repartition.
+        val customReadersWithNum = collect(smjWithNum.head) {
+          case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+        }
+        assert(customReadersWithNum.isEmpty)
+
+        // Repartition with non-default partition num specified.
+        val dfRepartitionWithNum2 = df.repartition(3, 'b)
+        dfRepartitionWithNum2.collect()
+        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
+        // The top shuffle from repartition is not optimized out.
+        assert(hasRepartitionShuffle(planWithNum2))
+        val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
+        assert(smjWithNum2.length == 1)
+        // Skew join can apply as the repartition is not optimized out.
+        assert(smjWithNum2.head.isSkewJoin)
+      }
+    }
+  }
 }
```
