Commit 90d084f

Add code path to bypass merge-sort in ExternalSorter, and tests
1 parent 31e5d7c commit 90d084f

3 files changed: +287 additions, -60 deletions

3 files changed

+287
-60
lines changed

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 13 additions & 18 deletions
@@ -58,25 +58,20 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    // Get an iterator with the elements for each partition ID
-    val externalSorter: ExternalSorter[K, V, _] = {
-      if (dep.mapSideCombine) {
-        if (!dep.aggregator.isDefined) {
-          throw new IllegalStateException("Aggregator is empty for map-side combine")
-        }
-        val sorter = new ExternalSorter[K, V, C](
-          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-        sorter.insertAll(records)
-        sorter
-      } else {
-        // In this case we pass neither an aggregator nor an ordering to the sorter, because we
-        // don't care whether the keys get sorted in each partition; that will be done on the
-        // reduce side if the operation being run is sortByKey.
-        val sorter = new ExternalSorter[K, V, V](
-          None, Some(dep.partitioner), None, dep.serializer)
-        sorter.insertAll(records)
-        sorter
+    if (dep.mapSideCombine) {
+      if (!dep.aggregator.isDefined) {
+        throw new IllegalStateException("Aggregator is empty for map-side combine")
       }
+      sorter = new ExternalSorter[K, V, C](
+        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+      sorter.insertAll(records)
+    } else {
+      // In this case we pass neither an aggregator nor an ordering to the sorter, because we
+      // don't care whether the keys get sorted in each partition; that will be done on the
+      // reduce side if the operation being run is sortByKey.
+      sorter = new ExternalSorter[K, V, V](
+        None, Some(dep.partitioner), None, dep.serializer)
+      sorter.insertAll(records)
     }
 
     // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
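With this change, write() stores the sorter in a field and, when there is no map-side combine, builds it with neither an aggregator nor an ordering, which is exactly the case the new bypass path in ExternalSorter targets. As context, here is a minimal, hypothetical driver program (not part of this commit; the settings shown are assumptions apart from the bypass threshold default of 200 added below) that would take that path: a sort-based shuffle, no combiner, no key ordering, and fewer than 200 reduce partitions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object BypassMergeSortExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical settings; spark.shuffle.sort.bypassMergeThreshold defaults to 200 per this patch.
    val conf = new SparkConf()
      .setMaster("local[4]")
      .setAppName("bypass-merge-sort-example")
      .set("spark.shuffle.manager", "sort")                   // use the sort-based shuffle writer
      .set("spark.shuffle.sort.bypassMergeThreshold", "200")
    val sc = new SparkContext(conf)

    // groupByKey does no map-side combine and requests no key ordering, and 100 reduce
    // partitions is below the threshold, so each map task writes one file per partition
    // and concatenates them instead of merge-sorting its spills.
    val counts = sc.parallelize(1 to 100000, 8)
      .map(i => (i % 100, i))
      .groupByKey(100)
      .mapValues(_.size)
      .collect()

    println(counts.sortBy(_._1).take(5).mkString(", "))
    sc.stop()
  }
}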

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 160 additions & 41 deletions
@@ -27,8 +27,8 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark._
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
-import org.apache.spark.storage.BlockId
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.{BlockObjectWriter, BlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -67,6 +67,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics
  * for equality to merge values.
  *
  * - Users are expected to call stop() at the end to delete all the intermediate files.
+ *
+ * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
+ * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
+ * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
+ * then concatenate these files to produce a single sorted file, without having to serialize and
+ * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
+ * groupBy, sort, etc operations since they do no partial aggregation.
  */
private[spark] class ExternalSorter[K, V, C](
    aggregator: Option[Aggregator[K, V, C]] = None,
@@ -124,6 +131,18 @@ private[spark] class ExternalSorter[K, V, C](
   // How much of the shared memory pool this collection has claimed
   private var myMemoryThreshold = 0L
 
+  // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
+  // local aggregation and sorting, write numPartitions files directly and just concatenate them
+  // at the end. This avoids doing serialization and deserialization twice to merge together the
+  // spilled files, which would happen with the normal code path. The downside is more small files
+  // and possibly more I/O if these fall out of the buffer cache.
+  private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+  private[collection] val bypassMergeSort =  // private[collection] for testing
+    (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
+
+  // Array of file writers for each partition, used if bypassMergeSort is true
+  private var partitionWriters: Array[BlockObjectWriter] = null
+
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
   // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
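The bypassMergeSort flag above reduces to a simple predicate over the shuffle dependency. A standalone sketch (hypothetical object and method names, not part of the patch) of the same rule, with a few sanity checks:

// Hypothetical helper mirroring the bypassMergeSort predicate: bypass the merge-sort only when
// there is no aggregator, no key ordering, and at most `threshold` output partitions.
object BypassDecision {
  def shouldBypass(
      numPartitions: Int,
      hasAggregator: Boolean,
      hasOrdering: Boolean,
      threshold: Int = 200): Boolean = {
    numPartitions <= threshold && !hasAggregator && !hasOrdering
  }

  def main(args: Array[String]): Unit = {
    // groupByKey into 100 partitions: no map-side aggregator, no ordering -> bypass
    assert(shouldBypass(numPartitions = 100, hasAggregator = false, hasOrdering = false))
    // map-side combine (e.g. reduceByKey) passes an aggregator -> normal merge-sort path
    assert(!shouldBypass(numPartitions = 100, hasAggregator = true, hasOrdering = false))
    // wide shuffle with 1000 partitions -> normal path even without aggregation
    assert(!shouldBypass(numPartitions = 1000, hasAggregator = false, hasOrdering = false))
    println("bypass decision checks passed")
  }
}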
@@ -137,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
-  // A comparator for (Int, K) elements that orders them by partition and then possibly by key
+  // A comparator for (Int, K) pairs that orders them by only their partition ID
+  private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] {
+    override def compare(a: (Int, K), b: (Int, K)): Int = {
+      a._1 - b._1
+    }
+  }
+
+  // A comparator that orders (Int, K) pairs by partition ID and then possibly by key
   private val partitionKeyComparator: Comparator[(Int, K)] = {
     if (ordering.isDefined || aggregator.isDefined) {
       // Sort by partition ID then key comparator
@@ -153,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C](
       }
     } else {
       // Just sort it by partition ID
-      new Comparator[(Int, K)] {
-        override def compare(a: (Int, K), b: (Int, K)): Int = {
-          a._1 - b._1
-        }
-      }
+      partitionComparator
     }
   }
 
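A small self-contained sketch (hypothetical, using java.util.Arrays directly) of what sorting with a partition-only comparator such as partitionComparator achieves: elements end up grouped by partition ID while keys within a partition are left unsorted. The `a._1 - b._1` subtraction is safe here because partition IDs are small non-negative integers, so it cannot overflow.

import java.util.{Arrays, Comparator}

object PartitionComparatorDemo {
  def main(args: Array[String]): Unit = {
    // (partitionId, key) pairs, analogous to the (Int, K) keys kept by ExternalSorter
    val elems: Array[(Int, String)] = Array((2, "b"), (0, "z"), (1, "a"), (0, "c"), (2, "a"))

    // Orders only by partition ID, like partitionComparator above
    val byPartition = new Comparator[(Int, String)] {
      override def compare(a: (Int, String), b: (Int, String)): Int = a._1 - b._1
    }

    Arrays.sort(elems, byPartition)
    println(elems.mkString(", "))  // partition 0 elements first, then 1, then 2; keys not sorted
  }
}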

@@ -242,6 +264,38 @@ private[spark] class ExternalSorter[K, V, C](
     val threadId = Thread.currentThread().getId
     logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
       .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+
+    if (bypassMergeSort) {
+      spillToPartitionFiles(collection)
+    } else {
+      spillToMergeableFile(collection)
+    }
+
+    if (usingMap) {
+      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+    } else {
+      buffer = new SizeTrackingPairBuffer[(Int, K), C]
+    }
+
+    // Release our memory back to the shuffle pool so that other threads can grab it
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0
+
+    _memoryBytesSpilled += memorySize
+  }
+
+  /**
+   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
+   * We add this file into `spills` so that we can find it later.
+   *
+   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
+   * See spillToPartitionFiles() for that code path.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(!bypassMergeSort)
+
     val (blockId, file) = diskBlockManager.createTempBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -304,18 +358,35 @@ private[spark] class ExternalSorter[K, V, C](
       }
     }
 
-    if (usingMap) {
-      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
-    } else {
-      buffer = new SizeTrackingPairBuffer[(Int, K), C]
-    }
+    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+  }
 
-    // Release our memory back to the shuffle pool so that other threads can grab it
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0
+  /**
+   * Spill our in-memory collection to separate files, one for each partition. This is used when
+   * there's no aggregator and ordering and the number of partitions is small, because it allows
+   * writePartitionedFile to just concatenate files without deserializing data.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(bypassMergeSort)
+
+    // Create our file writers if we haven't done so yet
+    if (partitionWriters == null) {
+      partitionWriters = Array.fill(numPartitions) {
+        val (blockId, file) = diskBlockManager.createTempBlock()
+        blockManager.getDiskWriter(blockId, file, ser, fileBufferSize).open()
+      }
+    }
 
-    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
-    _memoryBytesSpilled += memorySize
+    val it = collection.iterator  // No need to sort stuff, just write each element out
+    while (it.hasNext) {
+      val elem = it.next()
+      val partitionId = elem._1._1
+      val key = elem._1._2
+      val value = elem._2
+      partitionWriters(partitionId).write((key, value))
+    }
   }
 
   /**
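As a plain-JDK analogue of spillToPartitionFiles (hypothetical file layout and stream types, not the BlockObjectWriter API), the idea is simply to keep one open writer per partition and route every record to its partition's file, with no sorting at all:

import java.io.{DataOutputStream, File, FileOutputStream}

object SpillToPartitionFilesSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val dir = new File(System.getProperty("java.io.tmpdir"), "spill-sketch")
    dir.mkdirs()

    // One writer per partition (ExternalSorter creates them lazily on the first spill).
    val writers = Array.tabulate(numPartitions) { p =>
      new DataOutputStream(new FileOutputStream(new File(dir, s"part-$p")))
    }

    // Route each (partitionId, key, value) record to its partition's file; nothing is sorted.
    val records = (1 to 1000).map(i => (i % numPartitions, i.toString, i))
    for ((partitionId, key, value) <- records) {
      writers(partitionId).writeUTF(key)
      writers(partitionId).writeInt(value)
    }
    writers.foreach(_.close())

    (0 until numPartitions).foreach { p =>
      println(s"part-$p -> ${new File(dir, s"part-$p").length()} bytes")
    }
  }
}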
@@ -479,7 +550,6 @@ private[spark] class ExternalSorter[K, V, C](
 
       skipToNextPartition()
 
-
       // Intermediate file and deserializer streams that read from exactly one batch
       // This guards against pre-fetching and other arbitrary behavior of higher level streams
       var fileStream: FileInputStream = null
@@ -619,23 +689,25 @@ private[spark] class ExternalSorter[K, V, C](
   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
-    if (spills.isEmpty) {
+    if (spills.isEmpty && partitionWriters == null) {
       // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
-        // The user isn't requested sorted keys, so only sort by partition ID, not key
-        val partitionComparator = new Comparator[(Int, K)] {
-          override def compare(a: (Int, K), b: (Int, K)): Int = {
-            a._1 - b._1
-          }
-        }
+        // The user hasn't requested sorted keys, so only sort by partition ID, not key
         groupByPartition(collection.destructiveSortedIterator(partitionComparator))
       } else {
         // We do need to sort by both partition ID and key
         groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
       }
+    } else if (bypassMergeSort) {
+      // Read data from each partition file and merge it together with the data in memory;
+      // note that there's no ordering or aggregator in this case -- we just partition objects
+      val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+      collIter.map { case (partitionId, values) =>
+        (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
+      }
     } else {
-      // General case: merge spilled and in-memory data
+      // Merge spilled and in-memory data
       merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
     }
   }
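In the bypassMergeSort branch above, each partition's output is just the in-memory elements followed by the elements already spilled for that partition. A tiny sketch (hypothetical data, standing in for `values ++ readPartitionFile(...)`) of that composition, which relies on Iterator concatenation being lazy:

object PartitionedIteratorSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for the in-memory collection and a spilled partition file's contents.
    val inMemory: Iterator[(String, Int)] = Iterator(("a", 1), ("b", 2))
    val spilled: Iterator[(String, Int)] = Iterator(("c", 3), ("d", 4))

    // ++ is lazy: the spilled side is only consumed after the in-memory side is exhausted,
    // and no key ordering is implied -- elements are only grouped by partition.
    val combined = inMemory ++ spilled
    println(combined.toList)  // List((a,1), (b,2), (c,3), (d,4))
  }
}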
@@ -653,7 +725,7 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param blockId block ID to write to. The index file will be blockId.name + ".index".
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
-   * @return array of lengths, in bytes, for each partition of the file (for map output tracker)
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
   def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = {
     val outputFile = blockManager.diskBlockManager.getFile(blockId)
@@ -666,21 +738,51 @@ private[spark] class ExternalSorter[K, V, C](
     var totalBytes = 0L
     var totalTime = 0L
 
-    for ((id, elements) <- this.partitionedIterator) {
-      if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
-        for (elem <- elements) {
-          writer.write(elem)
+    if (bypassMergeSort && partitionWriters != null) {
+      // We decided to write separate files for each partition, so just concatenate them. To keep
+      // this simple we spill out the current in-memory collection so that everything is in files.
+      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
+      partitionWriters.foreach(_.commitAndClose())
+      var out: FileOutputStream = null
+      var in: FileInputStream = null
+      try {
+        out = new FileOutputStream(outputFile)
+        for (i <- 0 until numPartitions) {
+          val file = partitionWriters(i).fileSegment().file
+          in = new FileInputStream(file)
+          org.apache.spark.util.Utils.copyStream(in, out)
+          in.close()
+          in = null
+          lengths(i) = file.length()
+          offsets(i + 1) = offsets(i) + lengths(i)
+        }
+      } finally {
+        if (out != null) {
+          out.close()
+        }
+        if (in != null) {
+          in.close()
+        }
+      }
+    } else {
+      // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
+      // partition and just write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        if (elements.hasNext) {
+          val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+          for (elem <- elements) {
+            writer.write(elem)
+          }
+          writer.commitAndClose()
+          val segment = writer.fileSegment()
+          offsets(id + 1) = segment.offset + segment.length
+          lengths(id) = segment.length
+          totalTime += writer.timeWriting()
+          totalBytes += segment.length
+        } else {
+          // The partition is empty; don't create a new writer to avoid writing headers, etc
+          offsets(id + 1) = offsets(id)
         }
-        writer.commitAndClose()
-        val segment = writer.fileSegment()
-        offsets(id + 1) = segment.offset + segment.length
-        lengths(id) = segment.length
-        totalTime += writer.timeWriting()
-        totalBytes += segment.length
-      } else {
-        // The partition is empty; don't create a new writer to avoid writing headers, etc
-        offsets(id + 1) = offsets(id)
       }
     }
 
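A self-contained sketch (plain java.io with a hypothetical helper, rather than BlockObjectWriter and Utils.copyStream) of the concatenation step above: append each partition file to the single map output file in partition order and record per-partition lengths, from which the index offsets are a running sum.

import java.io.{File, FileInputStream, FileOutputStream}

object ConcatenatePartitionFilesSketch {
  /** Append each partition file to `output` in partition order, returning each file's byte length. */
  def concatenate(partitionFiles: Seq[File], output: File): Array[Long] = {
    val lengths = new Array[Long](partitionFiles.size)
    val out = new FileOutputStream(output)
    try {
      for ((file, i) <- partitionFiles.zipWithIndex) {
        val in = new FileInputStream(file)
        try {
          val buf = new Array[Byte](8192)
          var n = in.read(buf)
          while (n != -1) {
            out.write(buf, 0, n)
            n = in.read(buf)
          }
        } finally {
          in.close()
        }
        lengths(i) = file.length()
      }
    } finally {
      out.close()
    }
    lengths
  }

  def main(args: Array[String]): Unit = {
    val dir = new File(System.getProperty("java.io.tmpdir"), "concat-sketch")
    dir.mkdirs()
    // Three dummy partition files of different sizes
    val parts = (0 until 3).map { p =>
      val f = new File(dir, s"part-$p")
      val w = new FileOutputStream(f)
      w.write(Array.fill[Byte](100 * (p + 1))(p.toByte))
      w.close()
      f
    }
    val lengths = concatenate(parts, new File(dir, "merged"))
    // Offsets for an index file are just the running sum of the lengths.
    val offsets = lengths.scanLeft(0L)(_ + _)
    println(s"lengths = ${lengths.mkString(",")}, offsets = ${offsets.mkString(",")}")
  }
}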

@@ -711,9 +813,26 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  /**
+   * Read a partition file back as an iterator (used in our iterator method)
+   */
+  def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
+    if (writer.isOpen) {
+      writer.commitAndClose()
+    }
+    blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
+    if (partitionWriters != null) {
+      partitionWriters.foreach { w =>
+        w.revertPartialWritesAndClose()
+        diskBlockManager.getFile(w.blockId).delete()
+      }
+      partitionWriters = null
+    }
   }
 
   def memoryBytesSpilled: Long = _memoryBytesSpilled

0 commit comments
