[SPARK-33494][SQL][AQE] Do not use local shuffle reader for repartition
This PR updates `ShuffleExchangeExec` to carry more information about how much we can change the partitioning. For `repartition(col)`, we should preserve the user-specified partitioning and not apply the AQE local shuffle reader.

Similar to `repartition(number, col)`, we should respect the user-specified partitioning.
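As a hedged illustration of the behaviour this protects (the session setup, table, and column names below are assumed for the sketch and are not part of the patch):

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: a user-specified hash repartition under AQE. AQE may still
// tune the number of partitions, but the hash partitioning by `key` must
// survive, so the local shuffle reader is not applied to this exchange.
val spark = SparkSession.builder().master("local[4]").appName("spark-33494-sketch").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "true")

val df = spark.range(0, 1000)
  .select(('id % 10).as("key"), 'id.as("value"))
  .repartition('key)

df.collect()
```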

Does this PR introduce any user-facing change? No.

How was this patch tested? A new test.

Closes apache#30432 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan authored and Lorenzo Martini committed May 18, 2021
1 parent dd66b31 commit df83fd2
Showing 8 changed files with 83 additions and 44 deletions.
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
@@ -754,7 +754,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
planLater(child), noUserSpecifiedNumPartition = false) :: Nil
planLater(child), REPARTITION_WITH_NUM) :: Nil
} else {
execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
}
@@ -787,9 +787,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: logical.Range =>
execution.RangeExec(r) :: Nil
case r: logical.RepartitionByExpression =>
val canChangeNumParts = r.optNumPartitions.isEmpty
exchange.ShuffleExchangeExec(
r.partitioning, planLater(r.child), canChangeNumParts) :: Nil
val shuffleOrigin = if (r.optNumPartitions.isEmpty) {
REPARTITION
} else {
REPARTITION_WITH_NUM
}
exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
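For orientation (not part of the diff), a sketch of how the DataFrame repartition APIs flow through the strategy above into the new shuffle origins:

```scala
// df.repartition(10)        => logical.Repartition(10, shuffle = true)
//                           => ShuffleExchangeExec(RoundRobinPartitioning(10), child, REPARTITION_WITH_NUM)
//
// df.repartition('key)      => RepartitionByExpression(..., optNumPartitions = None)
//                           => ShuffleExchangeExec(HashPartitioning('key :: Nil, spark.sql.shuffle.partitions), child, REPARTITION)
//
// df.repartition(10, 'key)  => RepartitionByExpression(..., optNumPartitions = Some(10))
//                           => ShuffleExchangeExec(HashPartitioning('key :: Nil, 10), child, REPARTITION_WITH_NUM)
```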
@@ -18,8 +18,10 @@
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike}
import org.apache.spark.sql.internal.SQLConf

/**
@@ -50,7 +52,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPlan] {
val shuffleStages = collectShuffleStages(plan)
// ShuffleExchanges introduced by repartition do not support changing the number of partitions.
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
plan
} else {
// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
@@ -85,6 +87,11 @@ }
}
}
}

private def supportCoalesce(s: ShuffleExchangeLike): Boolean = {
s.outputPartitioning != SinglePartition &&
(s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION)
}
}

object CoalesceShufflePartitions {
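A minimal sketch of the resulting coalescing behaviour, assuming the `spark` session and `import spark.implicits._` from the earlier sketch:

```scala
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

// REPARTITION: the hash partitioning by the expression is preserved, but this
// rule may still shrink the number of shuffle partitions.
val byColumn = spark.range(1000).repartition('id % 10)

// REPARTITION_WITH_NUM: the user pinned the count, so the 200 partitions stay.
val byNumber = spark.range(1000).repartition(200, 'id % 10)
```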
@@ -17,9 +17,10 @@

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.internal.SQLConf

@@ -142,9 +143,13 @@ object OptimizeLocalShuffleReader {

def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
s.mapStats.isDefined && supportLocalReader(s.shuffle)
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs, _) =>
s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle)
case _ => false
}

private def supportLocalReader(s: ShuffleExchangeLike): Boolean = {
s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS
}
}
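To observe the rule's effect, one can look for local readers in the final adaptive plan; a hedged sketch that relies on `AdaptiveSparkPlanExec.executedPlan` and `CustomShuffleReaderExec.isLocalReader` as used elsewhere in the AQE code, with the `spark` session and implicits from the earlier sketch:

```scala
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, CustomShuffleReaderExec}

val df = spark.range(1000).repartition('id % 10)
df.collect() // run the query so AQE settles on a final plan

val finalPlan = df.queryExecution.executedPlan match {
  case a: AdaptiveSparkPlanExec => a.executedPlan
  case other => other
}
// With this change, a user-specified repartition keeps its shuffle readers non-local.
val localReaders = finalPlan.collect {
  case r: CustomShuffleReaderExec if r.isLocalReader => r
}
assert(localReaders.isEmpty)
```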
@@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange {
def numPartitions: Int

/**
* Returns whether the shuffle partition number can be changed.
* The origin of this shuffle operator.
*/
def canChangeNumPartitions: Boolean
def shuffleOrigin: ShuffleOrigin

/**
* The asynchronous job that materializes the shuffle.
@@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange {
def runtimeStatistics: Statistics
}

// Describes where the shuffle operator comes from.
sealed trait ShuffleOrigin

// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It
// means that the shuffle operator is used to ensure internal data partitioning requirements and
// Spark is free to optimize it as long as the requirements are still ensured.
case object ENSURE_REQUIREMENTS extends ShuffleOrigin

// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark
// can still optimize it via changing shuffle partition number, as data partitioning won't change.
case object REPARTITION extends ShuffleOrigin

// Indicates that the shuffle operator was added by the user-specified repartition operator with
// a certain partition number. Spark can't optimize it.
case object REPARTITION_WITH_NUM extends ShuffleOrigin

/**
* Performs a shuffle that will result in the desired partitioning.
*/
case class ShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {

// If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
// For `SinglePartition`, it requires exactly one partition and we can't change it either.
def canChangeNumPartitions: Boolean =
noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
extends ShuffleExchangeLike {

private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
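An illustrative helper (not part of the patch) that summarizes what AQE may do per origin, using only the objects defined above:

```scala
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, REPARTITION_WITH_NUM, ShuffleOrigin}

// (may coalesce the partition number, may use the local shuffle reader)
def allowedAqeOptimizations(origin: ShuffleOrigin): (Boolean, Boolean) = origin match {
  case ENSURE_REQUIREMENTS  => (true, true)   // Spark's own shuffle: fully optimizable
  case REPARTITION          => (true, false)  // repartition(col): keep the partitioning
  case REPARTITION_WITH_NUM => (false, false) // repartition(n, col): keep everything
}
```

Independently of the origin, a `SinglePartition` output partitioning disables both optimizations, as the `supportCoalesce` and `supportLocalReader` helpers above also check.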
20 changes: 10 additions & 10 deletions sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 23
-- Number of queries: 24


-- !query
@@ -89,7 +89,7 @@ Results [2]: [key#x, max#x]

(5) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(6) HashAggregate
Input [2]: [key#x, max#x]
@@ -100,7 +100,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]

(7) Exchange
Input [2]: [key#x, max(val)#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]

(8) Sort
Input [2]: [key#x, max(val)#x]
@@ -158,7 +158,7 @@ Results [2]: [key#x, max#x]

(5) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(6) HashAggregate
Input [2]: [key#x, max#x]
@@ -245,7 +245,7 @@ Results [2]: [key#x, val#x]

(9) Exchange
Input [2]: [key#x, val#x]
Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(10) HashAggregate
Input [2]: [key#x, val#x]
@@ -613,7 +613,7 @@ Results [2]: [key#x, max#x]

(5) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(6) HashAggregate
Input [2]: [key#x, max#x]
@@ -647,7 +647,7 @@ Results [2]: [key#x, max#x]

(11) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(12) HashAggregate
Input [2]: [key#x, max#x]
@@ -730,7 +730,7 @@ Results [3]: [count#xL, sum#xL, count#xL]

(3) Exchange
Input [3]: [count#xL, sum#xL, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(4) HashAggregate
Input [3]: [count#xL, sum#xL, count#xL]
@@ -776,7 +776,7 @@ Results [2]: [key#x, buf#x]

(3) Exchange
Input [2]: [key#x, buf#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(4) ObjectHashAggregate
Input [2]: [key#x, buf#x]
@@ -828,7 +828,7 @@ Results [2]: [key#x, min#x]

(4) Exchange
Input [2]: [key#x, min#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(5) Sort
Input [2]: [key#x, min#x]
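These golden-file changes only reflect the new EXPLAIN string: the Exchange node now prints the shuffle origin where it used to print a boolean. A hedged way to see this locally, assuming the `spark` session from the earlier sketch:

```scala
spark.range(100)
  .selectExpr("id % 5 AS key", "id AS value")
  .groupBy("key").max("value")
  .explain("formatted")
// The Exchange line now reads, e.g.:
//   Arguments: hashpartitioning(key#x, 200), ENSURE_REQUIREMENTS, [id=#x]
```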
28 changes: 14 additions & 14 deletions sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 23
-- Number of queries: 24


-- !query
@@ -92,7 +92,7 @@ Results [2]: [key#x, max#x]

(6) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(7) HashAggregate [codegen id : 2]
Input [2]: [key#x, max#x]
@@ -103,7 +103,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x]

(8) Exchange
Input [2]: [key#x, max(val)#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x]
Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x]

(9) Sort [codegen id : 3]
Input [2]: [key#x, max(val)#x]
@@ -160,7 +160,7 @@ Results [2]: [key#x, max#x]

(6) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(7) HashAggregate [codegen id : 2]
Input [2]: [key#x, max#x]
@@ -250,7 +250,7 @@ Results [2]: [key#x, val#x]

(11) Exchange
Input [2]: [key#x, val#x]
Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(12) HashAggregate [codegen id : 4]
Input [2]: [key#x, val#x]
@@ -469,7 +469,7 @@ Results [1]: [max#x]

(10) Exchange
Input [1]: [max#x]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(11) HashAggregate [codegen id : 2]
Input [1]: [max#x]
@@ -516,7 +516,7 @@ Results [1]: [max#x]

(17) Exchange
Input [1]: [max#x]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(18) HashAggregate [codegen id : 2]
Input [1]: [max#x]
@@ -600,7 +600,7 @@ Results [1]: [max#x]

(9) Exchange
Input [1]: [max#x]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(10) HashAggregate [codegen id : 2]
Input [1]: [max#x]
@@ -647,7 +647,7 @@ Results [2]: [sum#x, count#xL]

(16) Exchange
Input [2]: [sum#x, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(17) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
@@ -713,7 +713,7 @@ Results [2]: [sum#x, count#xL]

(7) Exchange
Input [2]: [sum#x, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(8) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
@@ -851,7 +851,7 @@ Results [2]: [key#x, max#x]

(6) Exchange
Input [2]: [key#x, max#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(7) HashAggregate [codegen id : 4]
Input [2]: [key#x, max#x]
@@ -943,7 +943,7 @@ Results [3]: [count#xL, sum#xL, count#xL]

(4) Exchange
Input [3]: [count#xL, sum#xL, count#xL]
Arguments: SinglePartition, true, [id=#x]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]

(5) HashAggregate [codegen id : 2]
Input [3]: [count#xL, sum#xL, count#xL]
@@ -988,7 +988,7 @@ Results [2]: [key#x, buf#x]

(4) Exchange
Input [2]: [key#x, buf#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(5) ObjectHashAggregate
Input [2]: [key#x, buf#x]
@@ -1039,7 +1039,7 @@ Results [2]: [key#x, min#x]

(5) Exchange
Input [2]: [key#x, min#x]
Arguments: hashpartitioning(key#x, 4), true, [id=#x]
Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x]

(6) Sort [codegen id : 2]
Input [2]: [key#x, min#x]
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -763,7 +763,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike {
override def numMappers: Int = delegate.numMappers
override def numPartitions: Int = delegate.numPartitions
override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions
override def shuffleOrigin: ShuffleOrigin = {
delegate.shuffleOrigin
}
override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
delegate.mapOutputStatisticsFuture
override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
@@ -1013,4 +1013,14 @@ class AdaptiveQueryExecSuite
}
}
}

test("SPARK-33494: Do not use local shuffle reader for repartition") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val df = spark.table("testData").repartition('key)
df.collect()
// The local shuffle reader breaks the partitioning and should not be used for a
// repartition operation specified by the user.
checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
}
}
}
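A hypothetical companion assertion (not in this patch) that the pinned partition count of `repartition(n, col)` also survives AQE:

```scala
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
  val df = spark.table("testData").repartition(10, 'key)
  df.collect()
  // REPARTITION_WITH_NUM: AQE must not coalesce away the 10 partitions the user asked for.
  assert(df.rdd.getNumPartitions == 10)
}
```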
