
Commit a2aa966

JkSelf authored and cloud-fan committed
[SPARK-29544][SQL] optimize skewed partition based on data size
### What changes were proposed in this pull request?

Data skew is common and can severely degrade query performance, especially for queries with joins. This PR optimizes skewed joins at runtime, based on the map output statistics, by adding the "OptimizeSkewedPartitions" rule. The detailed design doc is [here](https://docs.google.com/document/d/1NkXN-ck8jUOS0COz3f8LUW5xzF8j9HFjoZXWGGX2HAg/edit). Currently the Inner, Cross, LeftSemi, LeftAnti, LeftOuter and RightOuter join types are supported.

### Why are the changes needed?

To optimize skewed partitions at runtime, based on AQE.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Unit tests.

Closes #26434 from JkSelf/skewedPartitionBasedSize.

Lead-authored-by: jiake <ke.a.jia@intel.com>
Co-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: JiaKe <ke.a.jia@intel.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 2688fae commit a2aa966
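For context, here is a hedged sketch of how the feature added by this commit could be switched on from user code. The `spark.sql.adaptive.optimizeSkewedJoin.*` keys are the ones introduced in the SQLConf.scala diff below; `spark.sql.adaptive.enabled` and the session builder are standard Spark APIs, and the values shown are illustrative overrides rather than recommendations.

```scala
import org.apache.spark.sql.SparkSession

object SkewedJoinConfigSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("skewed-join-demo")
      .master("local[*]")
      // adaptive query execution must be on for the new rule to run
      .config("spark.sql.adaptive.enabled", "true")
      // keys below are added by this PR (see the SQLConf.scala diff)
      .config("spark.sql.adaptive.optimizeSkewedJoin.enabled", "true")
      .config("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionSizeThreshold", "64MB")
      .config("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionFactor", "10")
      .config("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionMaxSplits", "5")
      .getOrCreate()

    // Joins executed through this session can now have skewed shuffle partitions
    // detected and split at runtime by the OptimizeSkewedJoin rule.
    spark.stop()
  }
}
```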

File tree: 14 files changed, +703 −106 lines


core/src/main/scala/org/apache/spark/MapOutputTracker.scala (35 additions, 27 deletions)

@@ -343,15 +343,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   /**
    * Called from executors to get the server URIs and output sizes for each shuffle block that
    * needs to be read from a given range of map output partitions (startPartition is included but
-   * endPartition is excluded from the range) and is produced by a specific mapper.
+   * endPartition is excluded from the range) and is produced by
+   * a range of mappers (startMapIndex, endMapIndex, startMapIndex is included and
+   * the endMapIndex is excluded).
    *
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
    *         tuples describing the shuffle blocks that are stored at that block manager.
    */
-  def getMapSizesByMapIndex(
+  def getMapSizesByRange(
       shuffleId: Int,
-      mapIndex: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

@@ -688,20 +691,25 @@ private[spark] class MapOutputTrackerMaster(
   }

   /**
-   * Return the location where the Mapper ran. The locations each includes both a host and an
+   * Return the locations where the Mappers ran. The locations each includes both a host and an
    * executor id on that host.
    *
    * @param dep shuffle dependency object
-   * @param mapId the map id
+   * @param startMapIndex the start map index
+   * @param endMapIndex the end map index
    * @return a sequence of locations where task runs.
    */
-  def getMapLocation(dep: ShuffleDependency[_, _, _], mapId: Int): Seq[String] =
+  def getMapLocation(
+      dep: ShuffleDependency[_, _, _],
+      startMapIndex: Int,
+      endMapIndex: Int): Seq[String] =
   {
     val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
     if (shuffleStatus != null) {
       shuffleStatus.withMapStatuses { statuses =>
-        if (mapId >= 0 && mapId < statuses.length) {
-          Seq(statuses(mapId).location.host)
+        if (startMapIndex < endMapIndex && (startMapIndex >= 0 && endMapIndex < statuses.length)) {
+          val statusesPicked = statuses.slice(startMapIndex, endMapIndex).filter(_ != null)
+          statusesPicked.map(_.location.host).toSeq
         } else {
           Nil
         }

@@ -737,29 +745,26 @@ private[spark] class MapOutputTrackerMaster(
       case Some (shuffleStatus) =>
         shuffleStatus.withMapStatuses { statuses =>
           MapOutputTracker.convertMapStatuses(
-            shuffleId, startPartition, endPartition, statuses)
+            shuffleId, startPartition, endPartition, statuses, 0, shuffleStatus.mapStatuses.length)
         }
       case None =>
         Iterator.empty
     }
   }

-  override def getMapSizesByMapIndex(
+  override def getMapSizesByRange(
       shuffleId: Int,
-      mapIndex: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
+    logDebug(s"Fetching outputs for shuffle $shuffleId, mappers $startMapIndex-$endMapIndex" +
       s"partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
-      case Some (shuffleStatus) =>
+      case Some(shuffleStatus) =>
         shuffleStatus.withMapStatuses { statuses =>
           MapOutputTracker.convertMapStatuses(
-            shuffleId,
-            startPartition,
-            endPartition,
-            statuses,
-            Some(mapIndex))
+            shuffleId, startPartition, endPartition, statuses, startMapIndex, endMapIndex)
         }
       case None =>
         Iterator.empty

@@ -802,7 +807,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     val statuses = getStatuses(shuffleId, conf)
     try {
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses)
+        shuffleId, startPartition, endPartition, statuses, 0, statuses.length)
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:

@@ -811,17 +816,18 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }

-  override def getMapSizesByMapIndex(
+  override def getMapSizesByRange(
       shuffleId: Int,
-      mapIndex: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
+    logDebug(s"Fetching outputs for shuffle $shuffleId, mappers $startMapIndex-$endMapIndex" +
       s"partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId, conf)
     try {
-      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition,
-        statuses, Some(mapIndex))
+      MapOutputTracker.convertMapStatuses(
+        shuffleId, startPartition, endPartition, statuses, startMapIndex, endMapIndex)
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:

@@ -980,7 +986,8 @@ private[spark] object MapOutputTracker extends Logging {
    * @param startPartition Start of map output partition ID range (included in range)
    * @param endPartition End of map output partition ID range (excluded from range)
    * @param statuses List of map statuses, indexed by map partition index.
-   * @param mapIndex When specified, only shuffle blocks from this mapper will be processed.
+   * @param startMapIndex Start Map index.
+   * @param endMapIndex End Map index.
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
    *         tuples describing the shuffle blocks that are stored at that block manager.

@@ -990,11 +997,12 @@ private[spark] object MapOutputTracker extends Logging {
       startPartition: Int,
       endPartition: Int,
       statuses: Array[MapStatus],
-      mapIndex : Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      startMapIndex : Int,
+      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     assert (statuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
     val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- mapIndex.map(index => iter.filter(_._2 == index)).getOrElse(iter)) {
+    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
       if (status == null) {
         val errorMessage = s"Missing an output location for shuffle $shuffleId"
         logError(errorMessage)
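To make the new `convertMapStatuses` contract concrete, here is a tiny standalone sketch (plain Scala, no Spark dependencies, all names hypothetical) of the range selection it now performs: `zipWithIndex` pairs each status with its original map index, and `slice(startMapIndex, endMapIndex)` restricts iteration to the half-open mapper range.

```scala
object MapRangeSliceSketch {
  def main(args: Array[String]): Unit = {
    // hypothetical per-mapper shuffle output sizes, standing in for Array[MapStatus]
    val mapOutputSizes = Array(100L, 200L, 300L, 400L, 500L)
    val startMapIndex = 1
    val endMapIndex = 4

    // Same shape as the new loop in convertMapStatuses: the original map index survives
    // the slice, so shuffle block ids can still be built from it.
    val picked = mapOutputSizes.iterator.zipWithIndex.slice(startMapIndex, endMapIndex).toList
    println(picked) // List((200,1), (300,2), (400,3)); endMapIndex is excluded
  }
}
```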

core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala (6 additions, 4 deletions)

@@ -55,12 +55,14 @@ private[spark] trait ShuffleManager {
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

   /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive)
-   * that are produced by one specific mapper. Called on executors by reduce tasks.
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
+   * read from map output (startMapIndex to endMapIndex - 1, inclusive).
+   * Called on executors by reduce tasks.
    */
-  def getReaderForOneMapper[K, C](
+  def getReaderForRange[K, C](
       handle: ShuffleHandle,
-      mapIndex: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
       startPartition: Int,
       endPartition: Int,
       context: TaskContext,

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala (5 additions, 4 deletions)

@@ -131,15 +131,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
   }

-  override def getReaderForOneMapper[K, C](
+  override def getReaderForRange[K, C](
       handle: ShuffleHandle,
-      mapIndex: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
       startPartition: Int,
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByMapIndex(
-      handle.shuffleId, mapIndex, startPartition, endPartition)
+    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange(
+      handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
       shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
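With `getMapSizesByRange`/`getReaderForRange` in place, a skewed reduce partition can be served by several tasks, each reading a contiguous range of mappers. The sketch below only illustrates that idea with a simple greedy grouping; it is not the algorithm used by the OptimizeSkewedJoin rule, and all names in it are hypothetical.

```scala
object MapRangeSplitSketch {
  /**
   * Group map indices [0, sizes.length) into contiguous (startMapIndex, endMapIndex) ranges,
   * each holding roughly targetSize bytes. End indices are exclusive, matching getMapSizesByRange.
   */
  def splitByMapRanges(sizes: Array[Long], targetSize: Long): Seq[(Int, Int)] = {
    val ranges = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]
    var start = 0
    var acc = 0L
    for (i <- sizes.indices) {
      acc += sizes(i)
      if (acc >= targetSize) {
        ranges += ((start, i + 1)) // close the current range once it is "full"
        start = i + 1
        acc = 0L
      }
    }
    if (start < sizes.length) ranges += ((start, sizes.length))
    ranges.toSeq
  }

  def main(args: Array[String]): Unit = {
    // hypothetical per-mapper output sizes (bytes) for one skewed reduce partition
    val sizes = Array(30L, 10L, 50L, 5L, 40L, 25L)
    // yields two ranges, (0,3) and (3,6); each could be read by a separate task
    println(splitByMapRanges(sizes, targetSize = 50L))
  }
}
```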

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (30 additions, 0 deletions)

@@ -422,6 +422,36 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)

+  val ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED =
+    buildConf("spark.sql.adaptive.optimizeSkewedJoin.enabled")
+      .doc("When true and adaptive execution is enabled, a skewed join is automatically handled at " +
+        "runtime.")
+      .booleanConf
+      .createWithDefault(true)
+
+  val ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD =
+    buildConf("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionSizeThreshold")
+      .doc("Configures the minimum size in bytes for a partition that is considered as a skewed " +
+        "partition in adaptive skewed join.")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(64 * 1024 * 1024)
+
+  val ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR =
+    buildConf("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionFactor")
+      .doc("A partition is considered as a skewed partition if its size is larger than" +
+        " this factor multiple the median partition size and also larger than " +
+        s" ${ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key}")
+      .intConf
+      .createWithDefault(10)
+
+  val ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS =
+    buildConf("spark.sql.adaptive.optimizeSkewedJoin.skewedPartitionMaxSplits")
+      .doc("Configures the maximum number of task to handle a skewed partition in adaptive skewed" +
+        "join.")
+      .intConf
+      .checkValue( _ >= 1, "The split size at least be 1")
+      .createWithDefault(5)
+
   val NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN =
     buildConf("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin")
       .doc("The relation with a non-empty partition ratio lower than this config will not be " +

sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala (11 additions, 12 deletions)

@@ -116,7 +116,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
 class ShuffledRowRDD(
     var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
     metrics: Map[String, SQLMetric],
-    specifiedPartitionStartIndices: Option[Array[Int]] = None)
+    specifiedPartitionIndices: Option[Array[(Int, Int)]] = None)
   extends RDD[InternalRow](dependency.rdd.context, Nil) {

   if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) {

@@ -126,8 +126,8 @@ class ShuffledRowRDD(

   private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions

-  private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match {
-    case Some(indices) => indices
+  private[this] val partitionStartIndices: Array[Int] = specifiedPartitionIndices match {
+    case Some(indices) => indices.map(_._1)
     case None =>
       // When specifiedPartitionStartIndices is not defined, every post-shuffle partition
       // corresponds to a pre-shuffle partition.

@@ -142,16 +142,15 @@ class ShuffledRowRDD(
   override val partitioner: Option[Partitioner] = Some(part)

   override def getPartitions: Array[Partition] = {
-    assert(partitionStartIndices.length == part.numPartitions)
-    Array.tabulate[Partition](partitionStartIndices.length) { i =>
-      val startIndex = partitionStartIndices(i)
-      val endIndex =
-        if (i < partitionStartIndices.length - 1) {
-          partitionStartIndices(i + 1)
-        } else {
-          numPreShufflePartitions
+    specifiedPartitionIndices match {
+      case Some(indices) =>
+        Array.tabulate[Partition](indices.length) { i =>
+          new ShuffledRowRDDPartition(i, indices(i)._1, indices(i)._2)
+        }
+      case None =>
+        Array.tabulate[Partition](numPreShufflePartitions) { i =>
+          new ShuffledRowRDDPartition(i, i, i + 1)
         }
-      new ShuffledRowRDDPartition(i, startIndex, endIndex)
     }
   }
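A small standalone sketch of what this signature change buys: explicit (start, end) pairs let post-shuffle partitions describe arbitrary contiguous ranges of pre-shuffle partitions, instead of being derived from start indices that must tile the whole range. The names below are illustrative only, not the actual ShuffledRowRDD types.

```scala
object PartitionIndicesSketch {
  final case class PartitionSpec(postShuffleIndex: Int, startPreShuffle: Int, endPreShuffle: Int)

  // Mirrors the shape of the new getPartitions logic: explicit pairs when given,
  // otherwise the identity mapping of one post-shuffle partition per pre-shuffle partition.
  def buildPartitions(
      numPreShufflePartitions: Int,
      specifiedPartitionIndices: Option[Array[(Int, Int)]]): Array[PartitionSpec] =
    specifiedPartitionIndices match {
      case Some(indices) =>
        Array.tabulate(indices.length)(i => PartitionSpec(i, indices(i)._1, indices(i)._2))
      case None =>
        Array.tabulate(numPreShufflePartitions)(i => PartitionSpec(i, i, i + 1))
    }

  def main(args: Array[String]): Unit = {
    // coalesce pre-shuffle partitions 0-2 into one post-shuffle partition, keep 3 and 4 as-is
    println(buildPartitions(5, Some(Array((0, 3), (3, 4), (4, 5)))).mkString(", "))
    // no indices given: identity mapping
    println(buildPartitions(3, None).mkString(", "))
  }
}
```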

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala (4 additions, 0 deletions)

@@ -87,6 +87,10 @@ case class AdaptiveSparkPlanExec(
   // optimizations should be stage-independent.
   @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
     ReuseAdaptiveSubquery(conf, context.subqueryCache),
+    // Here the 'OptimizeSkewedPartitions' rule should be executed
+    // before 'ReduceNumShufflePartitions', as the skewed partition handled
+    // in 'OptimizeSkewedPartitions' rule, should be omitted in 'ReduceNumShufflePartitions'.
+    OptimizeSkewedJoin(conf),
     ReduceNumShufflePartitions(conf),
     // The rule of 'OptimizeLocalShuffleReader' need to make use of the 'partitionStartIndices'
     // in 'ReduceNumShufflePartitions' rule. So it must be after 'ReduceNumShufflePartitions' rule.

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala (4 additions, 2 deletions)

@@ -82,7 +82,7 @@ class LocalShuffledRowRDD(

   override def getPreferredLocations(partition: Partition): Seq[String] = {
     val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
-    tracker.getMapLocation(dependency, partition.index)
+    tracker.getMapLocation(dependency, partition.index, partition.index + 1)
   }

   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {

@@ -92,9 +92,11 @@ class LocalShuffledRowRDD(
     // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
     // as well as the `tempMetrics` for basic shuffle metrics.
     val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
-    val reader = SparkEnv.get.shuffleManager.getReaderForOneMapper(
+
+    val reader = SparkEnv.get.shuffleManager.getReaderForRange(
       dependency.shuffleHandle,
       mapIndex,
+      mapIndex + 1,
       localRowPartition.startPartition,
       localRowPartition.endPartition,
       context,

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala (1 addition, 1 deletion)

@@ -71,7 +71,7 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
     plan match {
       case c @ CoalescedShuffleReaderExec(s: ShuffleQueryStageExec, _) =>
         LocalShuffleReaderExec(
-          s, getPartitionStartIndices(s, Some(c.partitionStartIndices.length)))
+          s, getPartitionStartIndices(s, Some(c.partitionIndices.length)))
       case s: ShuffleQueryStageExec =>
         LocalShuffleReaderExec(s, getPartitionStartIndices(s, None))
     }
