
Commit e966284

mateiz authored and rxin committed
SPARK-2045 Sort-based shuffle
This adds a new ShuffleManager based on sorting, as described in https://issues.apache.org/jira/browse/SPARK-2045. The bulk of the code is in an ExternalSorter class that is similar to ExternalAppendOnlyMap, but sorts key-value pairs by partition ID and can be used to create a single sorted file with a map task's output. (Longer-term I think this can take on the remaining functionality in ExternalAppendOnlyMap and replace it so we don't have code duplication.)

The main TODOs still left are:
- [x] enabling ExternalSorter to merge across spilled files
  - [x] with an Ordering
  - [x] without an Ordering, using the keys' hash codes
- [x] adding more tests (e.g. a version of our shuffle suite that runs on this)
- [x] rebasing on top of the size-tracking refactoring in #1165 when that is merged
- [x] disabling spilling if spark.shuffle.spill is set to false

Despite these TODOs, this seems to work pretty well (running successfully in cases where the hash shuffle would OOM, such as 1000 reduce tasks on executors with only 1G memory), and it seems to be comparable in speed to, or faster than, hash-based shuffle (it creates far fewer files for the OS to keep track of). So I'm posting it to get some early feedback.

After these TODOs are done, I'd also like to enable ExternalSorter to sort data within each partition by a key as well, which will allow us to use it to implement external spilling in reduce tasks in `sortByKey`.

Author: Matei Zaharia <matei@databricks.com>

Closes #1499 from mateiz/sort-based-shuffle and squashes the following commits:

bd841f9 [Matei Zaharia] Various review comments
d1c137f [Matei Zaharia] Various review comments
a611159 [Matei Zaharia] Compile fixes due to rebase
62c56c8 [Matei Zaharia] Fix ShuffledRDD sometimes not returning Tuple2s.
f617432 [Matei Zaharia] Fix a failing test (seems to be due to change in SizeTracker logic)
9464d5f [Matei Zaharia] Simplify code and fix conflicts after latest rebase
0174149 [Matei Zaharia] Add cleanup behavior and cleanup tests for sort-based shuffle
eb4ee0d [Matei Zaharia] Remove customizable element type in ShuffledRDD
fa2e8db [Matei Zaharia] Allow nextBatchStream to be called after we're done looking at all streams
a34b352 [Matei Zaharia] Fix tracking of indices within a partition in SpillReader, and add test
03e1006 [Matei Zaharia] Add a SortShuffleSuite that runs ShuffleSuite with sort-based shuffle
3c7ff1f [Matei Zaharia] Obey the spark.shuffle.spill setting in ExternalSorter
ad65fbd [Matei Zaharia] Rebase on top of Aaron's Sorter change, and use Sorter in our buffer
44d2a93 [Matei Zaharia] Use estimateSize instead of atGrowThreshold to test collection sizes
5686f71 [Matei Zaharia] Optimize merging phase for in-memory only data:
5461cbb [Matei Zaharia] Review comments and more tests (e.g. tests with 1 element per partition)
e9ad356 [Matei Zaharia] Update ContextCleanerSuite to make sure shuffle cleanup tests use hash shuffle (since they were written for it)
c72362a [Matei Zaharia] Added bug fix and test for when iterators are empty
de1fb40 [Matei Zaharia] Make trait SizeTrackingCollection private[spark]
4988d16 [Matei Zaharia] tweak
c1b7572 [Matei Zaharia] Small optimization
ba7db7f [Matei Zaharia] Handle null keys in hash-based comparator, and add tests for collisions
ef4e397 [Matei Zaharia] Support for partial aggregation even without an Ordering
4b7a5ce [Matei Zaharia] More tests, and ability to sort data if a total ordering is given
e1f84be [Matei Zaharia] Fix disk block manager test
5a40a1c [Matei Zaharia] More tests
614f1b4 [Matei Zaharia] Add spill metrics to map tasks
cc52caf [Matei Zaharia] Add more error handling and tests for error cases
bbf359d [Matei Zaharia] More work
3a56341 [Matei Zaharia] More partial work towards sort-based shuffle
7a0895d [Matei Zaharia] Some more partial work towards sort-based shuffle
b615476 [Matei Zaharia] Scaffolding for sort-based shuffle
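To make the design concrete, here is a minimal, self-contained sketch of the core idea (plain Scala, no Spark dependency; every name in it is illustrative rather than the actual ExternalSorter API): map output is tagged with the reduce partition its key hashes to and emitted in partition-ID order, which is what lets a map task write a single sorted file instead of one file per reducer.

// Illustrative sketch only -- not Spark's ExternalSorter. Records are grouped by
// the reduce partition their key hashes to and emitted in partition-ID order,
// which is what lets a map task write one sorted file instead of one file per reducer.
object SortShuffleIdeaSketch {
  def main(args: Array[String]): Unit = {
    val numReducePartitions = 4
    val mapOutput = Seq("apple" -> 1, "banana" -> 2, "cherry" -> 3, "date" -> 4, "fig" -> 5)

    // Hash-partition each key, as a HashPartitioner would conceptually.
    def partitionOf(key: String): Int = math.abs(key.hashCode % numReducePartitions)

    // Sort by partition ID only; within a partition, order does not matter for the shuffle.
    val byPartition = mapOutput
      .map { case (k, v) => (partitionOf(k), (k, v)) }
      .sortBy(_._1)

    // The "single sorted file": contiguous runs of records, one run per partition.
    byPartition.groupBy(_._1).toSeq.sortBy(_._1).foreach { case (pid, recs) =>
      println(s"partition $pid: " + recs.map(_._2).mkString(", "))
    }
  }
}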
1 parent da50176 · commit e966284

35 files changed, +1969 −159 lines changed

core/src/main/scala/org/apache/spark/Aggregator.scala

Lines changed: 16 additions & 8 deletions
@@ -56,18 +56,23 @@ case class Aggregator[K, V, C] (
     } else {
       val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
       combiners.insertAll(iter)
-      // TODO: Make this non optional in a future release
-      Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
-      Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+      // Update task metrics if context is not null
+      // TODO: Make context non optional in a future release
+      Option(context).foreach { c =>
+        c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+        c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+      }
       combiners.iterator
     }
   }
 
   @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0")
-  def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
+  def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] =
     combineCombinersByKey(iter, null)
 
-  def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
+  def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
+      : Iterator[(K, C)] =
+  {
     if (!externalSorting) {
       val combiners = new AppendOnlyMap[K,C]
       var kc: Product2[K, C] = null
@@ -85,9 +90,12 @@ case class Aggregator[K, V, C] (
         val pair = iter.next()
         combiners.insert(pair._1, pair._2)
       }
-      // TODO: Make this non optional in a future release
-      Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
-      Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+      // Update task metrics if context is not null
+      // TODO: Make context non-optional in a future release
+      Option(context).foreach { c =>
+        c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+        c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+      }
       combiners.iterator
     }
   }
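Aside on the change from `=` to `+=` above: a single task can spill from more than one collection, so spill metrics have to accumulate rather than overwrite. A simplified sketch of the pattern (the TaskMetrics and Spillable types below are stand-ins for illustration, not Spark's classes):

// Simplified stand-ins for illustration only; Spark's real TaskMetrics differs.
object SpillMetricsSketch {
  final class TaskMetrics {
    var memoryBytesSpilled: Long = 0L
    var diskBytesSpilled: Long = 0L
  }
  final case class Spillable(memoryBytesSpilled: Long, diskBytesSpilled: Long)

  def main(args: Array[String]): Unit = {
    val metrics = new TaskMetrics
    val spillingCollections = Seq(Spillable(100L, 40L), Spillable(250L, 90L))
    // With plain assignment only the last collection's numbers would survive;
    // += accumulates the totals across everything the task spilled.
    spillingCollections.foreach { c =>
      metrics.memoryBytesSpilled += c.memoryBytesSpilled
      metrics.diskBytesSpilled += c.diskBytesSpilled
    }
    println(s"memory spilled: ${metrics.memoryBytesSpilled}, disk spilled: ${metrics.diskBytesSpilled}")
  }
}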

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 4 additions & 4 deletions
@@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging {
       value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
       executorEnvs(envKey) = value
     }
-    Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => 
+    Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
       executorEnvs("SPARK_PREPEND_CLASSES") = v
     }
     // The Mesos scheduler backend relies on this environment variable to set executor memory.
@@ -1203,10 +1203,10 @@
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
-   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively 
-   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt> 
+   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
    * if not.
-   * 
+   *
    * @param f the closure to clean
    * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
   * @throws <tt>SparkException<tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    */
   def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
     sample(withReplacement, fraction, Utils.random.nextLong)
-  
+
   /**
    * Return a sampled subset of this RDD.
    */

core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala

Lines changed: 4 additions & 3 deletions
@@ -17,10 +17,11 @@
 
 package org.apache.spark.rdd
 
+import scala.language.existentials
+
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.language.existentials
 
 import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       for ((it, depNum) <- rddIterators) {
         map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
       }
-      context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled
-      context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled
+      context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
+      context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
       new InterruptibleIterator(context,
         map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
     }

core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

Lines changed: 9 additions & 5 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Logging, RangePartitioner}
+import org.apache.spark.annotation.DeveloperApi
 
 /**
  * Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner}
  */
 class OrderedRDDFunctions[K : Ordering : ClassTag,
                           V: ClassTag,
-                          P <: Product2[K, V] : ClassTag](
+                          P <: Product2[K, V] : ClassTag] @DeveloperApi() (
     self: RDD[P])
-  extends Logging with Serializable {
-
+  extends Logging with Serializable
+{
   private val ordering = implicitly[Ordering[K]]
 
   /**
@@ -55,9 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
   * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
   * order of the keys).
   */
-  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+  // TODO: this currently doesn't work on P other than Tuple2!
+  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size)
+      : RDD[(K, V)] =
+  {
     val part = new RangePartitioner(numPartitions, self, ascending)
-    new ShuffledRDD[K, V, V, P](self, part)
+    new ShuffledRDD[K, V, V](self, part)
       .setKeyOrdering(if (ascending) ordering else ordering.reverse)
   }
 }
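A brief usage sketch of the reworked sortByKey (assuming a Spark 1.x classpath; the object and app names are made up for illustration): the result is now typed RDD[(K, V)] regardless of the Product2 subtype P of the input, built from a RangePartitioner plus a ShuffledRDD that carries the key ordering.

// Usage sketch under the assumption of a Spark 1.x classpath; names here are illustrative.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // brings in OrderedRDDFunctions via implicits in 1.x

object SortByKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sortByKey-sketch"))
    try {
      val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
      // Range-partitions by key and shuffles with a key ordering attached; result is RDD[(Int, String)].
      val sorted = pairs.sortByKey(ascending = true, numPartitions = 2)
      println(sorted.collect().toSeq) // expected: (1,a), (2,b), (3,c)
    } finally {
      sc.stop()
    }
  }
}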

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
     } else {
-      new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
+      new ShuffledRDD[K, V, C](self, partitioner)
         .setSerializer(serializer)
         .setAggregator(aggregator)
         .setMapSideCombine(mapSideCombine)
@@ -425,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     if (self.partitioner == Some(partitioner)) {
       self
     } else {
-      new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
+      new ShuffledRDD[K, V, V](self, partitioner)
     }
   }
 
core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 4 additions & 4 deletions
@@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag](
       val distributePartition = (index: Int, items: Iterator[T]) => {
         var position = (new Random(index)).nextInt(numPartitions)
         items.map { t =>
-          // Note that the hash code of the key will just be the key itself. The HashPartitioner 
+          // Note that the hash code of the key will just be the key itself. The HashPartitioner
           // will mod it with the number of total partitions.
           position = position + 1
           (position, t)
@@ -341,7 +341,7 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
+        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
         new HashPartitioner(numPartitions)),
         numPartitions).values
     } else {
@@ -352,8 +352,8 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, 
-      fraction: Double, 
+  def sample(withReplacement: Boolean,
+      fraction: Double,
       seed: Long = Utils.random.nextLong): RDD[T] = {
     require(fraction >= 0.0, "Negative fraction value: " + fraction)
     if (withReplacement) {

core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Lines changed: 9 additions & 8 deletions
@@ -37,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  * @tparam V the value class.
  * @tparam C the combiner class.
  */
+// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs
 @DeveloperApi
-class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
+class ShuffledRDD[K, V, C](
     @transient var prev: RDD[_ <: Product2[K, V]],
     part: Partitioner)
-  extends RDD[P](prev.context, Nil) {
+  extends RDD[(K, C)](prev.context, Nil) {
 
   private var serializer: Option[Serializer] = None
 
@@ -52,25 +53,25 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
   private var mapSideCombine: Boolean = false
 
   /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
-  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
+  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
     this.serializer = Option(serializer)
     this
   }
 
   /** Set key ordering for RDD's shuffle. */
-  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
+  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = {
     this.keyOrdering = Option(keyOrdering)
     this
   }
 
   /** Set aggregator for RDD's shuffle. */
-  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
+  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = {
     this.aggregator = Option(aggregator)
     this
   }
 
   /** Set mapSideCombine flag for RDD's shuffle. */
-  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
+  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = {
     this.mapSideCombine = mapSideCombine
     this
   }
@@ -85,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
     Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[P] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
     val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
       .read()
-      .asInstanceOf[Iterator[P]]
+      .asInstanceOf[Iterator[(K, C)]]
   }
 
   override def clearDependencies() {
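As a rough illustration of the new three-parameter signature (a sketch assuming a Spark 1.x classpath; ShuffledRDD is a DeveloperApi, so most code reaches it through PairRDDFunctions rather than constructing it directly):

// Sketch only: constructing the DeveloperApi ShuffledRDD[K, V, C] directly to show
// that the old Product2 element-type parameter P is gone and elements are plain (K, C).
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD

object ShuffledRDDSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("shuffledrdd-sketch"))
    try {
      val pairs = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3))
      // No aggregator here, so V and C are the same type; the setters still chain as before.
      val shuffled = new ShuffledRDD[String, Int, Int](pairs, new HashPartitioner(2))
        .setMapSideCombine(false)
      println(shuffled.collect().toSeq) // elements are ordinary Tuple2s
    } finally {
      sc.stop()
    }
  }
}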

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.shuffle._
  * A ShuffleManager using hashing, that creates one output file per reduce partition on each
  * mapper (possibly reusing these across waves of tasks).
  */
-class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
   /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
   override def registerShuffle[K, V, C](
       shuffleId: Int,
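The doc comment above is the other half of the file-count claim in the commit message: hash-based shuffle writes, in the worst case, one file per (map task, reduce partition), while sort-based shuffle writes a single partition-sorted file per map task. A back-of-the-envelope sketch with illustrative numbers:

// Back-of-the-envelope comparison of shuffle file counts; the task counts are illustrative.
object ShuffleFileCountSketch {
  def main(args: Array[String]): Unit = {
    val mapTasks = 1000
    val reduceTasks = 1000
    val hashShuffleFiles = mapTasks.toLong * reduceTasks // worst case: one file per (mapper, reduce partition)
    val sortShuffleFiles = mapTasks.toLong               // one partition-sorted output file per map task
    println(s"hash-based shuffle: $hashShuffleFiles files, sort-based shuffle: $sortShuffleFiles files")
  }
}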

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 3 additions & 2 deletions
@@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 
-class HashShuffleReader[K, C](
+private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
@@ -47,7 +47,8 @@ class HashShuffleReader[K, C](
     } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
       throw new IllegalStateException("Aggregator is empty for map-side combine")
     } else {
-      iter
+      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
+      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
     }
 
     // Sort the output if there is a sort ordering defined.
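The conversion added in the last hunk exists because the shuffle path can carry Product2 implementations that are not Tuple2s (for example reusable mutable pairs), while downstream RDDs are typed over plain (K, C). A small standalone sketch of the issue (ReusablePair is hypothetical, standing in for something like Spark's MutablePair):

// Standalone illustration: a Product2 that is not a Tuple2 would break code that
// pattern-matches on (k, v), hence the explicit copy into plain pairs in the reader.
final class ReusablePair[K, V](var _1: K, var _2: V) extends Product2[K, V] {
  def canEqual(that: Any): Boolean = that.isInstanceOf[ReusablePair[_, _]]
}

object PairConversionSketch {
  def main(args: Array[String]): Unit = {
    val raw: Iterator[Product2[String, Int]] =
      Iterator(new ReusablePair("a", 1), new ReusablePair("b", 2))
    // What the updated HashShuffleReader does: materialize ordinary Tuple2s.
    val pairs: Iterator[(String, Int)] = raw.map(p => (p._1, p._2))
    println(pairs.toList) // List((a,1), (b,2))
  }
}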
