
Commit c64df90

optimize skewed partition based on data size
1 parent 7759f71 commit c64df90

15 files changed, +769 -45 lines changed

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 106 additions & 6 deletions
@@ -355,6 +355,21 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

+  /**
+   * Called from executors to get the server URIs and output sizes for each shuffle block that
+   * needs to be read from a specific map output partition (partitionIndex) and is
+   * produced by a range of mappers (startMapId, endMapId).
+   *
+   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+   *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
+   */
+  def getMapSizesByRangeMapIndex(
+      shuffleId: Int,
+      partitionIndex: Int,
+      startMapId: Int,
+      endMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
   /**
    * Deletes map output status information for the specified shuffle stage.
    */
@@ -688,21 +703,28 @@ private[spark] class MapOutputTrackerMaster(
     }
   }

   /**
-   * Return the location where the Mapper ran. The locations each includes both a host and an
+   * Return the locations where the Mappers ran. The locations each include both a host and an
    * executor id on that host.
    *
    * @param dep shuffle dependency object
-   * @param mapId the map id
+   * @param startMapId the start map id
+   * @param endMapId the end map id
    * @return a sequence of locations where task runs.
    */
-  def getMapLocation(dep: ShuffleDependency[_, _, _], mapId: Int): Seq[String] =
+  def getMapLocation(
+      dep: ShuffleDependency[_, _, _],
+      startMapId: Int,
+      endMapId: Int): Seq[String] =
   {
     val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
     if (shuffleStatus != null) {
       shuffleStatus.withMapStatuses { statuses =>
-        if (mapId >= 0 && mapId < statuses.length) {
-          Seq( ExecutorCacheTaskLocation(statuses(mapId).location.host,
-            statuses(mapId).location.executorId).toString)
+        if (startMapId < endMapId && (startMapId >= 0 && endMapId < statuses.length)) {
+          val statusesPicked = statuses.slice(startMapId, endMapId).filter(_ != null)
+          statusesPicked.map { status =>
+            ExecutorCacheTaskLocation(status.location.host,
+              status.location.executorId).toString
+          }.toSeq
         } else {
           Nil
         }
@@ -767,6 +789,22 @@ private[spark] class MapOutputTrackerMaster(
     }
   }

+  override def getMapSizesByRangeMapIndex(
+      shuffleId: Int,
+      partitionIndex: Int,
+      startMapId: Int,
+      endMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.withMapStatuses { statuses =>
+          MapOutputTracker.convertMapStatuses(
+            shuffleId, partitionIndex, statuses, startMapId, endMapId)
+        }
+      case None =>
+        Iterator.empty
+    }
+  }
+
   override def stop(): Unit = {
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
@@ -831,6 +869,22 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }

+  override def getMapSizesByRangeMapIndex(
+      shuffleId: Int,
+      partitionIndex: Int,
+      startMapId: Int,
+      endMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    val statuses = getStatuses(shuffleId, conf)
+    try {
+      MapOutputTracker.convertMapStatuses(shuffleId, partitionIndex, statuses, startMapId, endMapId)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
+  }
+
   /**
    * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
    * on this array when reading it, because on the driver, we may be changing it in place.
@@ -1013,4 +1067,50 @@ private[spark] object MapOutputTracker extends Logging {

     splitsByAddress.iterator
   }
+
+  /**
+   * Given an array of map statuses, a specific map output partition (partitionIndex) and a
+   * range of mappers (startMapId, endMapId), returns a sequence that, for each block manager
+   * ID, lists the shuffle block IDs and corresponding shuffle block sizes stored at that
+   * block manager.
+   * Note that empty blocks are filtered out of the result.
+   *
+   * If any of the statuses is null (indicating a missing location due to a failed mapper),
+   * throws a FetchFailedException.
+   *
+   * @param shuffleId Identifier for the shuffle
+   * @param partitionIndex The specific map output partition ID
+   * @param statuses List of map statuses, indexed by map partition index.
+   * @param startMapId Start map ID
+   * @param endMapId End map ID
+   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+   *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
+   */
+  def convertMapStatuses(
+      shuffleId: Int,
+      partitionIndex: Int,
+      statuses: Array[MapStatus],
+      startMapId: Int,
+      endMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (statuses != null)
+    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+    val iter = statuses.iterator.zipWithIndex
+    for ((status, mapIndex) <- iter.slice(startMapId, endMapId)) {
+      if (status == null) {
+        val errorMessage = s"Missing an output location for shuffle $shuffleId"
+        logError(errorMessage)
+        throw new MetadataFetchFailedException(shuffleId, partitionIndex, errorMessage)
+      } else {
+        val size = status.getSizeForBlock(partitionIndex)
+        if (size != 0) {
+          splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+            ((ShuffleBlockId(shuffleId, status.mapId, partitionIndex), size, mapIndex))
+        }
+      }
+    }
+
+    splitsByAddress.iterator
+  }
+
 }
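
For intuition about the new `convertMapStatuses` overload above: it walks only the map-status slice `[startMapId, endMapId)`, skips empty blocks, and buckets the remaining blocks of the requested reduce partition by the block manager that holds them. Below is a minimal standalone sketch of that grouping, using simplified placeholder types (`Loc`, `Block`) instead of Spark's `BlockManagerId`/`MapStatus`; it illustrates the idea and is not Spark code.

```scala
import scala.collection.mutable.{HashMap, ListBuffer}

object ConvertMapStatusesSketch {
  // Simplified stand-ins for BlockManagerId and the (block id, size, map index) entries.
  case class Loc(host: String)
  case class Block(partition: Int, size: Long, mapIndex: Int)

  // blockSizes(mapIndex) = (location of that map output, size of this reduce partition's block)
  def groupByLocation(
      blockSizes: Array[(Loc, Long)],
      partitionIndex: Int,
      startMapId: Int,
      endMapId: Int): Iterator[(Loc, Seq[Block])] = {
    val splitsByAddress = new HashMap[Loc, ListBuffer[Block]]
    for ((entry, mapIndex) <- blockSizes.iterator.zipWithIndex.slice(startMapId, endMapId)) {
      val (loc, size) = entry
      if (size != 0) { // empty blocks are filtered out, as in convertMapStatuses
        splitsByAddress.getOrElseUpdate(loc, ListBuffer()) +=
          Block(partition = partitionIndex, size = size, mapIndex = mapIndex)
      }
    }
    splitsByAddress.iterator.map { case (loc, blocks) => (loc, blocks.toSeq) }
  }

  def main(args: Array[String]): Unit = {
    val sizes = Array((Loc("host-a"), 10L), (Loc("host-b"), 0L), (Loc("host-a"), 7L))
    // Read only map outputs 1 and 2 of reduce partition 3; the empty block on host-b is skipped.
    groupByLocation(sizes, partitionIndex = 3, startMapId = 1, endMapId = 3).foreach(println)
  }
}
```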

core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala

Lines changed: 12 additions & 0 deletions
@@ -66,6 +66,18 @@ private[spark] trait ShuffleManager
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

+  /**
+   * Get a reader for the specific partitionIndex of the map outputs that are
+   * produced by a range of mappers. Called on executors by reduce tasks.
+   */
+  def getReaderForRangeMapper[K, C](
+      handle: ShuffleHandle,
+      partitionIndex: Int,
+      startMapId: Int,
+      endMapId: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
+
   /**
    * Remove a shuffle's metadata from the ShuffleManager.
    * @return true if the metadata removed successfully, otherwise false.
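
To show how the `(startMapId, endMapId)` arguments of `getReaderForRangeMapper` would be chosen, here is a self-contained sketch that divides the map outputs feeding one skewed reduce partition into at most `maxSplits` contiguous half-open ranges, each of which would correspond to one reader call. The equal-width splitting shown is an illustrative assumption, not necessarily the exact policy of the `OptimizeSkewedPartitions` rule in this commit.

```scala
object MapRangeSplitSketch {
  /**
   * Splits the mapper indices [0, numMappers) into at most maxSplits contiguous
   * half-open ranges (startMapId, endMapId). Each range is a candidate unit of
   * work for one reduce task reading a slice of a skewed partition.
   */
  def splitMapperRange(numMappers: Int, maxSplits: Int): Seq[(Int, Int)] = {
    require(numMappers > 0 && maxSplits > 0)
    val numSplits = math.min(maxSplits, numMappers)
    val step = math.ceil(numMappers.toDouble / numSplits).toInt
    (0 until numMappers by step).map { start =>
      (start, math.min(start + step, numMappers))
    }
  }

  def main(args: Array[String]): Unit = {
    // 11 map outputs with maxSplits = 5 => ranges (0,3), (3,6), (6,9), (9,11)
    println(splitMapperRange(numMappers = 11, maxSplits = 5))
  }
}
```

With the default `spark.sql.adaptive.skewedPartitionMaxSplits = 5` introduced in SQLConf.scala below, a skewed partition written by 11 mappers would, under this splitting assumption, be read by four tasks.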

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 14 additions & 0 deletions
@@ -145,6 +145,20 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
   }

+  override def getReaderForRangeMapper[K, C](
+      handle: ShuffleHandle,
+      partitionIndex: Int,
+      startMapId: Int,
+      endMapId: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRangeMapIndex(
+      handle.shuffleId, partitionIndex, startMapId, endMapId)
+    new BlockStoreShuffleReader(
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
+      shouldBatchFetch = canUseBatchFetch(partitionIndex, partitionIndex + 1, context))
+  }
+
   /** Get a writer for a given partition. Called on executors by map tasks. */
   override def getWriter[K, V](
       handle: ShuffleHandle,

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 37 additions & 0 deletions
@@ -402,6 +402,34 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)

+  val ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED = buildConf("spark.sql.adaptive.skewedJoin.enabled")
+    .doc("When true and adaptive execution is enabled, a skewed join is automatically handled at " +
+      "runtime.")
+    .booleanConf
+    .createWithDefault(true)
+
+  val ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR =
+    buildConf("spark.sql.adaptive.skewedPartitionFactor")
+      .doc("A partition is considered as a skewed partition if its size is larger than " +
+        "this factor multiplying the median partition size and also larger than " +
+        "spark.sql.adaptive.skewedPartitionSizeThreshold.")
+      .intConf
+      .createWithDefault(10)
+
+  val ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD =
+    buildConf("spark.sql.adaptive.skewedPartitionSizeThreshold")
+      .doc("Configures the minimum size in bytes for a partition that is considered as a skewed " +
+        "partition in adaptive skewed join.")
+      .longConf
+      .createWithDefault(64 * 1024 * 1024L)
+
+  val ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS =
+    buildConf("spark.sql.adaptive.skewedPartitionMaxSplits")
+      .doc("Configures the maximum number of tasks to handle a skewed partition in adaptive " +
+        "skewed join.")
+      .intConf
+      .createWithDefault(5)
+
   val NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN =
     buildConf("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin")
       .doc("The relation with a non-empty partition ratio lower than this config will not be " +
@@ -2178,6 +2206,15 @@ class SQLConf extends Serializable with Logging {
   def maxNumPostShufflePartitions: Int =
     getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS).getOrElse(numShufflePartitions)

+  def adaptiveSkewedJoinEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED)
+
+  def adaptiveSkewedFactor: Int = getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR)
+
+  def adaptiveSkewedSizeThreshold: Long =
+    getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD)
+
+  def adaptiveSkewedMaxSplits: Int = getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS)
+
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)

   def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)
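
Read together, the docs of these three new configs give the skew criterion: a partition counts as skewed when its size exceeds both `skewedPartitionFactor` times the median partition size and `skewedPartitionSizeThreshold`, and such a partition is handled by at most `skewedPartitionMaxSplits` tasks. The sketch below simply encodes that reading of the documentation with the defaults from this commit; the helper names are made up, and this is not the optimizer rule's actual implementation.

```scala
object SkewDetectionSketch {
  /** Median of the per-partition shuffle sizes (simple exact median for the sketch). */
  def medianSize(sizes: Array[Long]): Long = {
    val sorted = sizes.sorted
    sorted(sorted.length / 2)
  }

  /** Skewed = larger than factor x median AND larger than the absolute threshold. */
  def isSkewed(size: Long, median: Long, factor: Int, threshold: Long): Boolean =
    size > median * factor && size > threshold

  def main(args: Array[String]): Unit = {
    // Defaults from this commit: factor = 10, threshold = 64MB, maxSplits = 5.
    val factor = 10
    val threshold = 64L * 1024 * 1024
    val maxSplits = 5

    val sizes = Array[Long](20, 25, 30, 28, 2048).map(_ * 1024 * 1024)
    val median = medianSize(sizes)

    sizes.zipWithIndex.foreach { case (size, i) =>
      if (isSkewed(size, median, factor, threshold)) {
        // A skewed partition would be read by up to maxSplits tasks, each covering
        // a contiguous range of map outputs (see getReaderForRangeMapper above).
        println(s"partition $i is skewed (${size / 1024 / 1024} MB), split into <= $maxSplits tasks")
      }
    }
  }
}
```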

sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala

Lines changed: 10 additions & 8 deletions
@@ -116,7 +116,8 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
 class ShuffledRowRDD(
     var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
     metrics: Map[String, SQLMetric],
-    specifiedPartitionStartIndices: Option[Array[Int]] = None)
+    specifiedPartitionStartIndices: Option[Array[Int]] = None,
+    specifiedPartitionEndIndices: Option[Array[Int]] = None)
   extends RDD[InternalRow](dependency.rdd.context, Nil) {

   if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) {
@@ -134,23 +135,24 @@
     (0 until numPreShufflePartitions).toArray
   }

-  private[this] val part: Partitioner =
-    new CoalescedPartitioner(dependency.partitioner, partitionStartIndices)
-
   override def getDependencies: Seq[Dependency[_]] = List(dependency)

-  override val partitioner: Option[Partitioner] = Some(part)
+  override val partitioner: Option[Partitioner] = specifiedPartitionEndIndices match {
+    case Some(indices) => None
+    case None => Some(new CoalescedPartitioner(dependency.partitioner, partitionStartIndices))
+  }

   override def getPartitions: Array[Partition] = {
-    assert(partitionStartIndices.length == part.numPartitions)
     Array.tabulate[Partition](partitionStartIndices.length) { i =>
       val startIndex = partitionStartIndices(i)
-      val endIndex =
-        if (i < partitionStartIndices.length - 1) {
+      val endIndex = specifiedPartitionEndIndices match {
+        case Some(indices) => indices(i)
+        case None => if (i < partitionStartIndices.length - 1) {
           partitionStartIndices(i + 1)
         } else {
           numPreShufflePartitions
         }
+      }
       new ShuffledRowRDDPartition(i, startIndex, endIndex)
     }
   }
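
With the new optional `specifiedPartitionEndIndices`, a post-shuffle partition's range can end before the next start index, so the ranges no longer need to tile the whole pre-shuffle partition space (which is also why the coalesced partitioner is dropped in that case). Below is a small standalone sketch of the index arithmetic from `getPartitions`, using plain arrays outside of Spark; the object and method names are illustrative only.

```scala
object PartitionRangeSketch {
  /** Mirrors the (startIndex, endIndex) computation in ShuffledRowRDD.getPartitions. */
  def partitionRanges(
      numPreShufflePartitions: Int,
      startIndices: Array[Int],
      endIndices: Option[Array[Int]]): Array[(Int, Int)] = {
    Array.tabulate(startIndices.length) { i =>
      val start = startIndices(i)
      val end = endIndices match {
        case Some(ends) => ends(i) // explicit end: a range may stop short of the next start
        case None =>
          if (i < startIndices.length - 1) startIndices(i + 1)
          else numPreShufflePartitions // last range runs to the end
      }
      (start, end)
    }
  }

  def main(args: Array[String]): Unit = {
    // Without end indices, ranges tile [0, 10): (0,4), (4,7), (7,10)
    println(partitionRanges(10, Array(0, 4, 7), None).toSeq)
    // With end indices, e.g. a skewed partition 4 can be left out: (0,4), (5,7), (7,10)
    println(partitionRanges(10, Array(0, 5, 7), Some(Array(4, 7, 10))).toSeq)
  }
}
```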

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 4 additions & 0 deletions
@@ -99,6 +99,10 @@ case class AdaptiveSparkPlanExec(
     // This rule must be executed before `ReduceNumShufflePartitions`, as local shuffle readers
     // can't change number of partitions.
     OptimizeLocalShuffleReader(conf),
+    // The 'OptimizeSkewedPartitions' rule must be executed before 'ReduceNumShufflePartitions',
+    // as the skewed partitions handled by 'OptimizeSkewedPartitions' should be omitted in
+    // 'ReduceNumShufflePartitions'.
+    OptimizeSkewedPartitions(conf),
     ReduceNumShufflePartitions(conf),
     ApplyColumnarRulesAndInsertTransitions(session.sessionState.conf,
       session.sessionState.columnarRules),

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ class LocalShuffledRowRDD(

   override def getPreferredLocations(partition: Partition): Seq[String] = {
     val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
-    tracker.getMapLocation(dependency, partition.index)
+    tracker.getMapLocation(dependency, partition.index, partition.index + 1)
   }

   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
