Skip to content

Commit a89c364

Browse files
carsonwang authored and vinodkc committed
[SPARK-28356][SQL] Do not reduce the number of partitions for repartition in adaptive execution
## What changes were proposed in this pull request?

Adaptive execution reduces the number of post-shuffle partitions at runtime, even for shuffles caused by repartition. However, users most likely want to get exactly the requested number of partitions when they call repartition, even in adaptive execution. This PR adds an internal config to control this, and by default adaptive execution will not change the number of post-shuffle partitions for repartition.

## How was this patch tested?

New tests added.

Closes apache#25121 from carsonwang/AE_repartition.

Authored-by: Carson Wang <carson.wang@intel.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 47fbe53 commit a89c364

File tree

6 files changed

+25
-26
lines changed

6 files changed

+25
-26
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
703703

704704
case logical.Repartition(numPartitions, shuffle, child) =>
705705
if (shuffle) {
706-
ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
706+
ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
707+
planLater(child), canChangeNumPartitions = false) :: Nil
707708
} else {
708709
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
709710
}
@@ -736,7 +737,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
736737
case r: logical.Range =>
737738
execution.RangeExec(r) :: Nil
738739
case r: logical.RepartitionByExpression =>
739-
exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil
740+
exchange.ShuffleExchangeExec(
741+
r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil
740742
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
741743
case r: LogicalRDD =>
742744
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,18 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
6161
// If not all leaf nodes are query stages, it's not safe to reduce the number of
6262
// shuffle partitions, because we may break the assumption that all children of a spark plan
6363
// have same number of output partitions.
64+
return plan
65+
}
66+
67+
val shuffleStages = plan.collect {
68+
case stage: ShuffleQueryStageExec => stage
69+
case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage
70+
}
71+
// ShuffleExchanges introduced by repartition do not support changing the number of partitions.
72+
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
73+
if (!shuffleStages.forall(_.plan.canChangeNumPartitions)) {
6474
plan
6575
} else {
66-
val shuffleStages = plan.collect {
67-
case stage: ShuffleQueryStageExec => stage
68-
case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage
69-
}
7076
val shuffleMetrics = shuffleStages.map { stage =>
7177
val metricsFuture = stage.mapOutputStatisticsFuture
7278
assert(metricsFuture.isCompleted, "ShuffleQueryStageExec should already be ready")
@@ -76,12 +82,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
7682
// `ShuffleQueryStageExec` gives null mapOutputStatistics when the input RDD has 0 partitions,
7783
// we should skip it when calculating the `partitionStartIndices`.
7884
val validMetrics = shuffleMetrics.filter(_ != null)
79-
// We may get different pre-shuffle partition number if user calls repartition manually.
80-
// We don't reduce shuffle partition number in that case.
81-
val distinctNumPreShufflePartitions =
82-
validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
83-
84-
if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
85+
if (validMetrics.nonEmpty) {
8586
val partitionStartIndices = estimatePartitionStartIndices(validMetrics.toArray)
8687
// This transformation adds new nodes, so we must use `transformUp` here.
8788
plan.transformUp {

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
9393
val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
9494
child match {
9595
// If child is an exchange, we replace it with a new one having defaultPartitioning.
96-
case ShuffleExchangeExec(_, c) => ShuffleExchangeExec(defaultPartitioning, c)
96+
case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
9797
case _ => ShuffleExchangeExec(defaultPartitioning, child)
9898
}
9999
}
@@ -199,7 +199,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
199199

200200
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
201201
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
202-
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) =>
202+
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
203203
child.outputPartitioning match {
204204
case lower: HashPartitioning if upper.semanticEquals(lower) => child
205205
case _ => operator

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordCo
4343
*/
4444
case class ShuffleExchangeExec(
4545
override val outputPartitioning: Partitioning,
46-
child: SparkPlan) extends Exchange {
46+
child: SparkPlan,
47+
canChangeNumPartitions: Boolean = true) extends Exchange {
4748

4849
// NOTE: coordinator can be null after serialization/deserialization,
4950
// e.g. it can be null on the Executor side

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1365,7 +1365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
13651365
val agg = cp.groupBy('id % 2).agg(count('id))
13661366

13671367
agg.queryExecution.executedPlan.collectFirst {
1368-
case ShuffleExchangeExec(_, _: RDDScanExec) =>
1368+
case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
13691369
case BroadcastExchangeExec(_, _: RDDScanExec) =>
13701370
}.foreach { _ =>
13711371
fail(

sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -574,22 +574,17 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
574574
withSparkSession(test, 4, None)
575575
}
576576

577-
test("Union two datasets with different pre-shuffle partition number") {
577+
test("Do not reduce the number of shuffle partition for repartition") {
578578
val test: SparkSession => Unit = { spark: SparkSession =>
579-
val dataset1 = spark.range(3)
580-
val dataset2 = spark.range(3)
581-
582-
val resultDf = dataset1.repartition(2, dataset1.col("id"))
583-
.union(dataset2.repartition(3, dataset2.col("id"))).toDF()
579+
val ds = spark.range(3)
580+
val resultDf = ds.repartition(2, ds.col("id")).toDF()
584581

585582
checkAnswer(resultDf,
586-
Seq((0), (0), (1), (1), (2), (2)).map(i => Row(i)))
583+
Seq(0, 1, 2).map(i => Row(i)))
587584
val finalPlan = resultDf.queryExecution.executedPlan
588585
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
589-
// As the pre-shuffle partition number are different, we will skip reducing
590-
// the shuffle partition numbers.
591586
assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.length == 0)
592587
}
593-
withSparkSession(test, 100, None)
588+
withSparkSession(test, 200, None)
594589
}
595590
}

0 commit comments

Comments
 (0)