From df83fd20b48867daf49313eff71674974fdcc4cb Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 25 Nov 2020 02:02:32 +0000
Subject: [PATCH] [SPARK-33494][SQL][AQE] Do not use local shuffle reader for
 repartition

### What changes were proposed in this pull request?

This PR updates `ShuffleExchangeExec` to carry more information about how much
we can change the partitioning. For `repartition(col)`, we should preserve the
user-specified partitioning, so we must not apply the AQE local shuffle reader
to it. As we already do for `repartition(number, col)`, we should respect the
user-specified partitioning.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

a new test

Closes #30432 from cloud-fan/aqe.

Authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
---
 .../spark/sql/execution/SparkStrategies.scala | 13 +++++----
 .../adaptive/CoalesceShufflePartitions.scala  |  9 +++++-
 .../adaptive/OptimizeLocalShuffleReader.scala | 13 ++++++---
 .../exchange/ShuffleExchangeExec.scala        | 28 +++++++++++++------
 .../sql-tests/results/explain-aqe.sql.out     | 20 ++++++-------
 .../sql-tests/results/explain.sql.out         | 28 +++++++++----------
 .../sql/SparkSessionExtensionSuite.scala      |  6 ++--
 .../adaptive/AdaptiveQueryExecSuite.scala     | 10 +++++++
 8 files changed, 83 insertions(+), 44 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dbfd4bf7de440..fe1f45e34ce5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.aggregate.AggUtils
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.execution.python._
 import org.apache.spark.sql.execution.streaming._
@@ -754,7 +754,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
           ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
-            planLater(child), noUserSpecifiedNumPartition = false) :: Nil
+            planLater(child), REPARTITION_WITH_NUM) :: Nil
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
@@ -787,9 +787,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case r: logical.RepartitionByExpression =>
-        val canChangeNumParts = r.optNumPartitions.isEmpty
-        exchange.ShuffleExchangeExec(
-          r.partitioning, planLater(r.child), canChangeNumParts) :: Nil
+        val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
+          REPARTITION
+        } else {
+          REPARTITION_WITH_NUM
+        }
+        exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
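The strategy hunks above are where each user-facing API picks up its origin
tag. A minimal sketch of the three cases (the DataFrame calls are the standard
public API; which origin each shuffle ends up with follows from the pattern
matches above):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.range(100).toDF("key")

    // ENSURE_REQUIREMENTS: the planner inserts this shuffle itself to satisfy
    // the aggregate's distribution requirement; AQE may optimize it freely.
    df.groupBy("key").count()

    // REPARTITION: user-specified partitioning without a partition number.
    df.repartition(col("key"))

    // REPARTITION_WITH_NUM: user-specified partitioning and partition number.
    df.repartition(10, col("key"))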
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index 096d65f16e42f..11c8fee2ed296 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.adaptive

 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
 import org.apache.spark.sql.internal.SQLConf

 /**
@@ -50,7 +52,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
     val shuffleStages = collectShuffleStages(plan)
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
-    if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
+    if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
       plan
     } else {
       // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
@@ -85,6 +87,11 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
       }
     }
   }
+
+  private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
+    s.outputPartitioning != SinglePartition &&
+      (s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
+  }
 }

 object CoalesceShufflePartitions {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 31d1f34b64a65..f4891f8b76567 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -17,9 +17,10 @@

 package org.apache.spark.sql.execution.adaptive

+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.internal.SQLConf

@@ -142,9 +143,13 @@ object OptimizeLocalShuffleReader {

   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
     case s: ShuffleQueryStageExec =>
-      s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
-    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
-      s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
+      s.mapStats.isDefined && supportLocalReader(s.shuffle)
+    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _) =>
+      s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
     case _ => false
   }
+
+  private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
+    s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
+  }
 }
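Taken together, the two rules now gate on different origin sets: coalescing
preserves the data partitioning, so it stays legal for REPARTITION, while the
local shuffle reader rewrites the partitioning and is legal only for
ENSURE_REQUIREMENTS. A self-contained model of the two checks (simplified
stand-ins for illustration, not the real Spark classes):

    sealed trait ShuffleOrigin
    case object ENSURE_REQUIREMENTS extends ShuffleOrigin
    case object REPARTITION extends ShuffleOrigin
    case object REPARTITION_WITH_NUM extends ShuffleOrigin

    // Mirrors CoalesceShufflePartitions.supportCoalesce above.
    def supportCoalesce(origin: ShuffleOrigin, isSinglePartition: Boolean): Boolean =
      !isSinglePartition && (origin == ENSURE_REQUIREMENTS || origin == REPARTITION)

    // Mirrors OptimizeLocalShuffleReader.supportLocalReader above.
    def supportLocalReader(origin: ShuffleOrigin, isSinglePartition: Boolean): Boolean =
      !isSinglePartition && origin == ENSURE_REQUIREMENTS

    // REPARTITION may be coalesced but never gets a local reader;
    // REPARTITION_WITH_NUM gets neither optimization.
    assert(supportCoalesce(REPARTITION, isSinglePartition = false))
    assert(!supportLocalReader(REPARTITION, isSinglePartition = false))
    assert(!supportCoalesce(REPARTITION_WITH_NUM, isSinglePartition = false))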
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 24c736951fdc4..25462c28a4271 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange {
   def numPartitions: Int

   /**
-   * Returns whether the shuffle partition number can be changed.
+   * The origin of this shuffle operator.
    */
-  def canChangeNumPartitions: Boolean
+  def shuffleOrigin: ShuffleOrigin

   /**
    * The asynchronous job that materializes the shuffle.
@@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange {
   def runtimeStatistics: Statistics
 }

+// Describes where the shuffle operator comes from.
+sealed trait ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
+// means that the shuffle operator is used to ensure internal data partitioning requirements and
+// Spark is free to optimize it as long as the requirements are still ensured.
+case object ENSURE_REQUIREMENTS extends ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
+// can still optimize it via changing shuffle partition number, as data partitioning won't change.
+case object REPARTITION extends ShuffleOrigin
+
+// Indicates that the shuffle operator was added by the user-specified repartition operator with
+// a certain partition number. Spark can't optimize it.
+case object REPARTITION_WITH_NUM extends ShuffleOrigin
+
 /**
  * Performs a shuffle that will result in the desired partitioning.
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {
-
-  // If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
-  // For `SinglePartition`, it requires exactly one partition and we can't change it either.
-  def canChangeNumPartitions: Boolean =
-    noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
+    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
+  extends ShuffleExchangeLike {

   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
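Since the new constructor argument defaults to ENSURE_REQUIREMENTS, existing
call sites that build a `ShuffleExchangeExec` without naming an origin keep
their old optimizer-friendly behavior. Downstream code can now match on the
origin rather than a boolean; a hypothetical helper (the name and purpose are
illustrative, not part of the patch):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.{REPARTITION_WITH_NUM, ShuffleExchangeExec}

    // Does the plan contain a shuffle whose partition number was pinned by
    // the user via repartition(num) or repartition(num, col)?
    def hasUserPinnedShuffle(plan: SparkPlan): Boolean =
      plan.collectFirst {
        case s: ShuffleExchangeExec if s.shuffleOrigin == REPARTITION_WITH_NUM => s
      }.isDefined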
diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
index d13649354ae42..86c005bca7126 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 23
+-- Number of queries: 24


 -- !query
@@ -89,7 +89,7 @@ Results [2]: [key#x, max#x]

 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) HashAggregate
 Input [2]: [key#x, max#x]
@@ -100,7 +100,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]

 (7) Exchange
 Input [2]: [key#x, max(val)#x]
-Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
+Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]

 (8) Sort
 Input [2]: [key#x, max(val)#x]
@@ -158,7 +158,7 @@ Results [2]: [key#x, max#x]

 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) HashAggregate
 Input [2]: [key#x, max#x]
@@ -245,7 +245,7 @@ Results [2]: [key#x, val#x]

 (9) Exchange
 Input [2]: [key#x, val#x]
-Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (10) HashAggregate
 Input [2]: [key#x, val#x]
@@ -613,7 +613,7 @@ Results [2]: [key#x, max#x]

 (5) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) HashAggregate
 Input [2]: [key#x, max#x]
@@ -647,7 +647,7 @@ Results [2]: [key#x, max#x]

 (11) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (12) HashAggregate
 Input [2]: [key#x, max#x]
@@ -730,7 +730,7 @@ Results [3]: [count#xL, sum#xL, count#xL]

 (3) Exchange
 Input [3]: [count#xL, sum#xL, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (4) HashAggregate
 Input [3]: [count#xL, sum#xL, count#xL]
@@ -776,7 +776,7 @@ Results [2]: [key#x, buf#x]

 (3) Exchange
 Input [2]: [key#x, buf#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (4) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
@@ -828,7 +828,7 @@ Results [2]: [key#x, min#x]

 (4) Exchange
 Input [2]: [key#x, min#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (5) Sort
 Input [2]: [key#x, min#x]
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index 4c5b8407bdf20..e7eb99abbccda 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 23
+-- Number of queries: 24


 -- !query
@@ -92,7 +92,7 @@ Results [2]: [key#x, max#x]

 (6) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (7) HashAggregate [codegen id : 2]
 Input [2]: [key#x, max#x]
@@ -103,7 +103,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]

 (8) Exchange
 Input [2]: [key#x, max(val)#x]
-Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
+Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]

 (9) Sort [codegen id : 3]
 Input [2]: [key#x, max(val)#x]
@@ -160,7 +160,7 @@ Results [2]: [key#x, max#x]

 (6) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (7) HashAggregate [codegen id : 2]
 Input [2]: [key#x, max#x]
@@ -250,7 +250,7 @@ Results [2]: [key#x, val#x]

 (11) Exchange
 Input [2]: [key#x, val#x]
-Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (12) HashAggregate [codegen id : 4]
 Input [2]: [key#x, val#x]
@@ -469,7 +469,7 @@ Results [1]: [max#x]

 (10) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (11) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -516,7 +516,7 @@ Results [1]: [max#x]

 (17) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (18) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -600,7 +600,7 @@ Results [1]: [max#x]

 (9) Exchange
 Input [1]: [max#x]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (10) HashAggregate [codegen id : 2]
 Input [1]: [max#x]
@@ -647,7 +647,7 @@ Results [2]: [sum#x, count#xL]

 (16) Exchange
 Input [2]: [sum#x, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (17) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
@@ -713,7 +713,7 @@ Results [2]: [sum#x, count#xL]

 (7) Exchange
 Input [2]: [sum#x, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (8) HashAggregate [codegen id : 2]
 Input [2]: [sum#x, count#xL]
@@ -851,7 +851,7 @@ Results [2]: [key#x, max#x]

 (6) Exchange
 Input [2]: [key#x, max#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (7) HashAggregate [codegen id : 4]
 Input [2]: [key#x, max#x]
@@ -943,7 +943,7 @@ Results [3]: [count#xL, sum#xL, count#xL]

 (4) Exchange
 Input [3]: [count#xL, sum#xL, count#xL]
-Arguments: SinglePartition, true, [id=#x]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

 (5) HashAggregate [codegen id : 2]
 Input [3]: [count#xL, sum#xL, count#xL]
@@ -988,7 +988,7 @@ Results [2]: [key#x, buf#x]

 (4) Exchange
 Input [2]: [key#x, buf#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (5) ObjectHashAggregate
 Input [2]: [key#x, buf#x]
@@ -1039,7 +1039,7 @@ Results [2]: [key#x, min#x]

 (5) Exchange
 Input [2]: [key#x, min#x]
-Arguments: hashpartitioning(key#x, 4), true, [id=#x]
+Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

 (6) Sort [codegen id : 2]
 Input [2]: [key#x, min#x]
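Both golden files change the same way: the Exchange node's Arguments line now
prints the shuffle origin where it previously printed the
`canChangeNumPartitions` boolean. To see this locally (a sketch; operator
numbering, partition counts, and plan ids vary by configuration):

    // Any shuffling query shows the new field in EXPLAIN FORMATTED output.
    spark.range(10).toDF("key").groupBy("key").count().explain("formatted")
    // The Exchange node prints roughly:
    //   (N) Exchange
    //   Input [2]: [key#x, count#xL]
    //   Arguments: hashpartitioning(key#x, 200), ENSURE_REQUIREMENTS, [id=#x]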
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index e5e8bc6917799..8f5547e41ae1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -763,7 +763,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
 case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
   override def numMappers: Int = delegate.numMappers
   override def numPartitions: Int = delegate.numPartitions
-  override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
+  override def shuffleOrigin: ShuffleOrigin = {
+    delegate.shuffleOrigin
+  }
   override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
     delegate.mapOutputStatisticsFuture
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index c9992e484e672..cc8611ba40601 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1013,4 +1013,14 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-33494: Do not use local shuffle reader for repartition") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val df = spark.table("testData").repartition('key)
+      df.collect()
+      // The local shuffle reader breaks partitioning and shouldn't be used for
+      // a repartition operation specified by the user.
+      checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
+    }
+  }
 }
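A rough standalone version of what the new test asserts, without the suite's
`checkNumLocalShuffleReaders` helper (a sketch: it assumes an active
SparkSession with spark.sql.adaptive.enabled=true, and that
`AdaptiveSparkPlanExec.executedPlan` and `CustomShuffleReaderExec.isLocalReader`
are accessible, as they are from the suite's package):

    import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, CustomShuffleReaderExec}
    import org.apache.spark.sql.functions.col

    val df = spark.range(100).toDF("key").repartition(col("key"))
    df.collect()  // run the query so AQE finalizes the plan

    // Unwrap the adaptive plan and look for local shuffle readers.
    val finalPlan = df.queryExecution.executedPlan
      .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
    val localReaders = finalPlan.collect {
      case r: CustomShuffleReaderExec if r.isLocalReader => r
    }
    // repartition(col) must keep its hash partitioning: AQE may still coalesce
    // the shuffle, but it must not apply a local shuffle reader to it.
    assert(localReaders.isEmpty)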