@@ -27,8 +27,8 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark._
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
-import org.apache.spark.storage.BlockId
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.{BlockObjectWriter, BlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -67,6 +67,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics
  * for equality to merge values.
  *
  * - Users are expected to call stop() at the end to delete all the intermediate files.
+ *
+ * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
+ * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
+ * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
+ * then concatenate these files to produce a single sorted file, without having to serialize and
+ * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
+ * groupBy, sort, etc. operations since they do no partial aggregation.
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
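
The new paragraph above describes the bypass condition in prose; in the code it comes down to a boolean over the partition count and the absence of an aggregator and an ordering. Below is a minimal standalone Scala sketch of that decision rule; the object, method, and parameter names are illustrative, not ExternalSorter's actual members.

    // Standalone sketch of the bypass decision described above (not Spark code).
    object BypassDecision {
      def shouldBypassMergeSort(
          numPartitions: Int,
          bypassMergeThreshold: Int,
          hasAggregator: Boolean,
          hasOrdering: Boolean): Boolean = {
        // Matches the condition used by the patch: small partition count, no map-side
        // aggregation, no key ordering requested.
        numPartitions <= bypassMergeThreshold && !hasAggregator && !hasOrdering
      }

      def main(args: Array[String]): Unit = {
        // 100 reduce partitions with the default threshold of 200: bypass applies
        println(shouldBypassMergeSort(100, 200, hasAggregator = false, hasOrdering = false)) // true
        // Any map-side aggregation disables the bypass regardless of partition count
        println(shouldBypassMergeSort(100, 200, hasAggregator = true, hasOrdering = false))  // false
      }
    }
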
@@ -124,6 +131,18 @@ private[spark] class ExternalSorter[K, V, C](
   // How much of the shared memory pool this collection has claimed
   private var myMemoryThreshold = 0L
 
+  // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
+  // local aggregation and sorting, write numPartitions files directly and just concatenate them
+  // at the end. This avoids doing serialization and deserialization twice to merge together the
+  // spilled files, which would happen with the normal code path. The downside is more small files
+  // and possibly more I/O if these fall out of the buffer cache.
+  private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+  private[collection] val bypassMergeSort =  // private[collection] for testing
+    (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
+
+  // Array of file writers for each partition, used if bypassMergeSort is true
+  private var partitionWriters: Array[BlockObjectWriter] = null
+
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
   // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
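
The threshold above is an ordinary Spark configuration entry with a default of 200. As a small sketch, assuming only a Spark dependency on the classpath, a job could raise it like this, and the same lookup pattern reads it back:

    import org.apache.spark.SparkConf

    object ThresholdExample {
      def main(args: Array[String]): Unit = {
        // Allow the bypass path for up to 400 reduce partitions.
        val conf = new SparkConf().set("spark.shuffle.sort.bypassMergeThreshold", "400")
        // Same lookup pattern as in the hunk above: read the setting with 200 as the default.
        val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
        println(threshold)  // 400
      }
    }
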
@@ -137,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
-  // A comparator for (Int, K) elements that orders them by partition and then possibly by key
+  // A comparator for (Int, K) pairs that orders them by only their partition ID
+  private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] {
+    override def compare(a: (Int, K), b: (Int, K)): Int = {
+      a._1 - b._1
+    }
+  }
+
+  // A comparator that orders (Int, K) pairs by partition ID and then possibly by key
   private val partitionKeyComparator: Comparator[(Int, K)] = {
     if (ordering.isDefined || aggregator.isDefined) {
       // Sort by partition ID then key comparator
@@ -153,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C](
       }
     } else {
       // Just sort it by partition ID
-      new Comparator[(Int, K)] {
-        override def compare(a: (Int, K), b: (Int, K)): Int = {
-          a._1 - b._1
-        }
-      }
+      partitionComparator
     }
   }
 
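
The two comparators differ only in whether keys are consulted after the partition ID. A standalone sketch of the partition-only comparator applied to plain Scala pairs (the real code sorts a specialized in-memory buffer, not a List):

    import java.util.Comparator

    object ComparatorDemo {
      def main(args: Array[String]): Unit = {
        val partitionComparator = new Comparator[(Int, String)] {
          // Subtraction is safe here because partition IDs are small non-negative ints;
          // it could overflow for arbitrary integers.
          override def compare(a: (Int, String), b: (Int, String)): Int = a._1 - b._1
        }
        val data = Seq((2, "b"), (0, "z"), (1, "a"), (0, "c"))
        val byPartition = data.sortWith((x, y) => partitionComparator.compare(x, y) < 0)
        // Elements are grouped by partition; keys within a partition keep their original order.
        println(byPartition)  // List((0,z), (0,c), (1,a), (2,b))
      }
    }
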
@@ -242,6 +264,38 @@ private[spark] class ExternalSorter[K, V, C](
     val threadId = Thread.currentThread().getId
     logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
       .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+
+    if (bypassMergeSort) {
+      spillToPartitionFiles(collection)
+    } else {
+      spillToMergeableFile(collection)
+    }
+
+    if (usingMap) {
+      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+    } else {
+      buffer = new SizeTrackingPairBuffer[(Int, K), C]
+    }
+
+    // Release our memory back to the shuffle pool so that other threads can grab it
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0
+
+    _memoryBytesSpilled += memorySize
+  }
+
+  /**
+   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
+   * We add this file into spills to find it later.
+   *
+   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
+   * See spillToPartitionFiles() for that code path.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(!bypassMergeSort)
+
     val (blockId, file) = diskBlockManager.createTempBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -304,18 +358,35 @@ private[spark] class ExternalSorter[K, V, C](
       }
     }
 
-    if (usingMap) {
-      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
-    } else {
-      buffer = new SizeTrackingPairBuffer[(Int, K), C]
-    }
+    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+  }
 
-    // Release our memory back to the shuffle pool so that other threads can grab it
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0
+  /**
+   * Spill our in-memory collection to separate files, one for each partition. This is used when
+   * there's no aggregator and ordering and the number of partitions is small, because it allows
+   * writePartitionedFile to just concatenate files without deserializing data.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(bypassMergeSort)
+
+    // Create our file writers if we haven't done so yet
+    if (partitionWriters == null) {
+      partitionWriters = Array.fill(numPartitions) {
+        val (blockId, file) = diskBlockManager.createTempBlock()
+        blockManager.getDiskWriter(blockId, file, ser, fileBufferSize).open()
+      }
+    }
 
-    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
-    _memoryBytesSpilled += memorySize
+    val it = collection.iterator  // No need to sort stuff, just write each element out
+    while (it.hasNext) {
+      val elem = it.next()
+      val partitionId = elem._1._1
+      val key = elem._1._2
+      val value = elem._2
+      partitionWriters(partitionId).write((key, value))
+    }
   }
 
   /**
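
spillToMergeableFile keeps the existing spill format (one sorted, batched file per spill), while spillToPartitionFiles simply appends every element to an open per-partition file. A simplified standalone sketch of the per-partition pattern, using plain DataOutputStreams in place of Spark's BlockObjectWriter and serializer:

    import java.io.{BufferedOutputStream, DataOutputStream, File, FileOutputStream}

    object PartitionFileSpill {
      def main(args: Array[String]): Unit = {
        val numPartitions = 3
        // One temp file and one open writer per partition, kept for the lifetime of the sorter.
        val files = Array.tabulate(numPartitions)(i => File.createTempFile(s"spill-part-$i-", ".tmp"))
        val writers = files.map(f =>
          new DataOutputStream(new BufferedOutputStream(new FileOutputStream(f, /* append = */ true))))

        // One "spill": write each in-memory ((partitionId, key), value) element to its partition file.
        val inMemory = Seq(((0, "a"), 1), ((2, "b"), 2), ((0, "c"), 3), ((1, "d"), 4))
        for (((partitionId, key), value) <- inMemory) {
          writers(partitionId).writeUTF(key)
          writers(partitionId).writeInt(value)
        }
        writers.foreach(_.flush())
        files.zipWithIndex.foreach { case (f, i) => println(s"partition $i: ${f.length()} bytes") }
        writers.foreach(_.close())
        files.foreach(_.delete())
      }
    }
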
@@ -479,7 +550,6 @@ private[spark] class ExternalSorter[K, V, C](
 
       skipToNextPartition()
 
-
       // Intermediate file and deserializer streams that read from exactly one batch
       // This guards against pre-fetching and other arbitrary behavior of higher level streams
       var fileStream: FileInputStream = null
@@ -619,23 +689,25 @@ private[spark] class ExternalSorter[K, V, C](
   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
-    if (spills.isEmpty) {
+    if (spills.isEmpty && partitionWriters == null) {
       // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
-        // The user isn't requested sorted keys, so only sort by partition ID, not key
-        val partitionComparator = new Comparator[(Int, K)] {
-          override def compare(a: (Int, K), b: (Int, K)): Int = {
-            a._1 - b._1
-          }
-        }
+        // The user hasn't requested sorted keys, so only sort by partition ID, not key
         groupByPartition(collection.destructiveSortedIterator(partitionComparator))
       } else {
         // We do need to sort by both partition ID and key
         groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
       }
+    } else if (bypassMergeSort) {
+      // Read data from each partition file and merge it together with the data in memory;
+      // note that there's no ordering or aggregator in this case -- we just partition objects
+      val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+      collIter.map { case (partitionId, values) =>
+        (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
+      }
     } else {
-      // General case: merge spilled and in-memory data
+      // Merge spilled and in-memory data
       merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
     }
   }
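
In the bypass branch above, each partition's iterator is the in-memory elements for that partition followed by whatever was already spilled to that partition's file. A toy sketch of that concatenation, with Maps standing in for the in-memory collection and the spill files:

    object CombineMemoryAndSpills {
      def main(args: Array[String]): Unit = {
        val numPartitions = 3
        val inMemory: Map[Int, Seq[(String, Int)]] = Map(0 -> Seq(("a", 1)), 2 -> Seq(("e", 5)))
        val spilled: Map[Int, Seq[(String, Int)]] = Map(0 -> Seq(("b", 2)), 1 -> Seq(("c", 3), ("d", 4)))

        // For each partition: in-memory elements first, then the spilled ones, as one iterator.
        val combined: Iterator[(Int, Iterator[(String, Int)])] =
          Iterator.range(0, numPartitions).map { p =>
            (p, inMemory.getOrElse(p, Seq.empty).iterator ++ spilled.getOrElse(p, Seq.empty).iterator)
          }
        combined.foreach { case (p, it) => println(s"partition $p: ${it.toList}") }
      }
    }
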
@@ -653,7 +725,7 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param blockId block ID to write to. The index file will be blockId.name + ".index".
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
-   * @return array of lengths, in bytes, for each partition of the file (for map output tracker)
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
   def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = {
     val outputFile = blockManager.diskBlockManager.getFile(blockId)
@@ -666,21 +738,51 @@ private[spark] class ExternalSorter[K, V, C](
     var totalBytes = 0L
     var totalTime = 0L
 
-    for ((id, elements) <- this.partitionedIterator) {
-      if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
-        for (elem <- elements) {
-          writer.write(elem)
+    if (bypassMergeSort && partitionWriters != null) {
+      // We decided to write separate files for each partition, so just concatenate them. To keep
+      // this simple we spill out the current in-memory collection so that everything is in files.
+      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
+      partitionWriters.foreach(_.commitAndClose())
+      var out: FileOutputStream = null
+      var in: FileInputStream = null
+      try {
+        out = new FileOutputStream(outputFile)
+        for (i <- 0 until numPartitions) {
+          val file = partitionWriters(i).fileSegment().file
+          in = new FileInputStream(file)
+          org.apache.spark.util.Utils.copyStream(in, out)
+          in.close()
+          in = null
+          lengths(i) = file.length()
+          offsets(i + 1) = offsets(i) + lengths(i)
+        }
+      } finally {
+        if (out != null) {
+          out.close()
+        }
+        if (in != null) {
+          in.close()
+        }
+      }
+    } else {
+      // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
+      // partition and just write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        if (elements.hasNext) {
+          val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+          for (elem <- elements) {
+            writer.write(elem)
+          }
+          writer.commitAndClose()
+          val segment = writer.fileSegment()
+          offsets(id + 1) = segment.offset + segment.length
+          lengths(id) = segment.length
+          totalTime += writer.timeWriting()
+          totalBytes += segment.length
+        } else {
+          // The partition is empty; don't create a new writer to avoid writing headers, etc
+          offsets(id + 1) = offsets(id)
         }
-        writer.commitAndClose()
-        val segment = writer.fileSegment()
-        offsets(id + 1) = segment.offset + segment.length
-        lengths(id) = segment.length
-        totalTime += writer.timeWriting()
-        totalBytes += segment.length
-      } else {
-        // The partition is empty; don't create a new writer to avoid writing headers, etc
-        offsets(id + 1) = offsets(id)
       }
     }
 
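
The concatenation loop above is what makes the bypass path cheap: each partition file is streamed into the single output file while the code records per-partition lengths (returned for the map output tracker) and cumulative offsets (for the index file). A standalone sketch of the same bookkeeping, using java.nio's Files.copy in place of Spark's copyStream utility:

    import java.io.{File, FileOutputStream}
    import java.nio.file.Files

    object ConcatenatePartitionFiles {
      // Append each partition file to a single output file, recording lengths and offsets.
      def concatenate(partitionFiles: Seq[File], outputFile: File): (Array[Long], Array[Long]) = {
        val lengths = new Array[Long](partitionFiles.size)
        val offsets = new Array[Long](partitionFiles.size + 1)
        val out = new FileOutputStream(outputFile)
        try {
          for ((file, i) <- partitionFiles.zipWithIndex) {
            lengths(i) = Files.copy(file.toPath, out)  // returns the number of bytes copied
            offsets(i + 1) = offsets(i) + lengths(i)
          }
        } finally {
          out.close()
        }
        (lengths, offsets)
      }

      def main(args: Array[String]): Unit = {
        val parts = Seq("aa", "bbb", "c").map { s =>
          val f = File.createTempFile("part-", ".tmp")
          Files.write(f.toPath, s.getBytes("UTF-8"))
          f
        }
        val output = File.createTempFile("shuffle-", ".data")
        val (lengths, offsets) = concatenate(parts, output)
        println(lengths.toList)  // List(2, 3, 1)
        println(offsets.toList)  // List(0, 2, 5, 6)
      }
    }
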
@@ -711,9 +813,26 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  /**
+   * Read a partition file back as an iterator (used in our iterator method)
+   */
+  def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
+    if (writer.isOpen) {
+      writer.commitAndClose()
+    }
+    blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
+    if (partitionWriters != null) {
+      partitionWriters.foreach { w =>
+        w.revertPartialWritesAndClose()
+        diskBlockManager.getFile(w.blockId).delete()
+      }
+      partitionWriters = null
+    }
   }
 
   def memoryBytesSpilled: Long = _memoryBytesSpilled
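
readPartitionFile closes a partition writer and streams its file back through the block manager and the configured serializer, and stop() reverts and deletes those files so an aborted or finished task leaves nothing behind. As a rough counterpart to the earlier per-partition spill sketch, here is a hypothetical reader that returns one partition file's (key, value) pairs as an iterator, using a plain DataInputStream instead of the block manager:

    import java.io.{BufferedInputStream, DataInputStream, EOFException, File, FileInputStream}

    object ReadPartitionFile {
      // Lazily read (String, Int) pairs written with writeUTF/writeInt until end of file.
      def read(file: File): Iterator[(String, Int)] = new Iterator[(String, Int)] {
        private val in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))
        private var nextElem: Option[(String, Int)] = fetch()

        private def fetch(): Option[(String, Int)] =
          try {
            Some((in.readUTF(), in.readInt()))
          } catch {
            case _: EOFException => in.close(); None
          }

        def hasNext: Boolean = nextElem.isDefined
        def next(): (String, Int) = {
          val e = nextElem.get
          nextElem = fetch()
          e
        }
      }
    }
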