
Commit 31e5d7c (parent: a263a7e)

Move existing logic for writing partitioned files into ExternalSorter
Also renamed ExternalSorter.write(Iterator) to insertAll, to match ExternalAppendOnlyMap

File tree: 4 files changed, +102 -69 lines

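At a glance, the map-side change (reconstructed from the SortShuffleWriter and ExternalSorter diffs below; these are fragments of SortShuffleWriter.write, not a standalone compile unit):

    // Before this commit: the writer fed the sorter, then looped over
    // sorter.partitionedIterator to write each partition's segment and
    // finally wrote blockId.name + ".index" with the partition offsets itself.
    sorter.write(records)

    // After this commit: the sorter owns that logic and the writer just delegates.
    sorter.insertAll(records)                                        // renamed from write()
    val partitionLengths = sorter.writePartitionedFile(blockId, context)
    mapStatus = new MapStatus(blockManager.blockManagerId,
      partitionLengths.map(MapOutputTracker.compressSize))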

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ private[spark] class HashShuffleReader[K, C](
         // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
         // the ExternalSorter won't spill to disk.
         val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
-        sorter.write(aggregatedIter)
+        sorter.insertAll(aggregatedIter)
         context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
         context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
         sorter.iterator

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 15 additions & 48 deletions

@@ -44,6 +44,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var sorter: ExternalSorter[K, V, _] = null
   private var outputFile: File = null
+  private var indexFile: File = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -58,77 +59,40 @@ private[spark] class SortShuffleWriter[K, V, C](
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
     // Get an iterator with the elements for each partition ID
-    val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
+    val externalSorter: ExternalSorter[K, V, _] = {
       if (dep.mapSideCombine) {
         if (!dep.aggregator.isDefined) {
           throw new IllegalStateException("Aggregator is empty for map-side combine")
         }
-        sorter = new ExternalSorter[K, V, C](
+        val sorter = new ExternalSorter[K, V, C](
           dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-        sorter.write(records)
-        sorter.partitionedIterator
+        sorter.insertAll(records)
+        sorter
       } else {
         // In this case we pass neither an aggregator nor an ordering to the sorter, because we
         // don't care whether the keys get sorted in each partition; that will be done on the
         // reduce side if the operation being run is sortByKey.
-        sorter = new ExternalSorter[K, V, V](
+        val sorter = new ExternalSorter[K, V, V](
          None, Some(dep.partitioner), None, dep.serializer)
-        sorter.write(records)
-        sorter.partitionedIterator
+        sorter.insertAll(records)
+        sorter
       }
     }
 
     // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
     // serve different ranges of this file using an index file that we create at the end.
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
-    outputFile = blockManager.diskBlockManager.getFile(blockId)
-
-    // Track location of each range in the output file
-    val offsets = new Array[Long](numPartitions + 1)
-    val lengths = new Array[Long](numPartitions)
-
-    for ((id, elements) <- partitions) {
-      if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize,
-          writeMetrics)
-        for (elem <- elements) {
-          writer.write(elem)
-        }
-        writer.commitAndClose()
-        val segment = writer.fileSegment()
-        offsets(id + 1) = segment.offset + segment.length
-        lengths(id) = segment.length
-      } else {
-        // The partition is empty; don't create a new writer to avoid writing headers, etc
-        offsets(id + 1) = offsets(id)
-      }
-    }
 
-    context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
-    context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+    outputFile = blockManager.diskBlockManager.getFile(blockId)
+    indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index")
 
-    // Write an index file with the offsets of each block, plus a final offset at the end for the
-    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
-    // out where each block begins and ends.
-
-    val diskBlockManager = blockManager.diskBlockManager
-    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
-    try {
-      var i = 0
-      while (i < numPartitions + 1) {
-        out.writeLong(offsets(i))
-        i += 1
-      }
-    } finally {
-      out.close()
-    }
+    val partitionLengths = sorter.writePartitionedFile(blockId, context)
 
     // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
     blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
 
     mapStatus = new MapStatus(blockManager.blockManagerId,
-      lengths.map(MapOutputTracker.compressSize))
+      partitionLengths.map(MapOutputTracker.compressSize))
   }
 
   /** Close this writer, passing along whether the map completed */
@@ -145,6 +109,9 @@ private[spark] class SortShuffleWriter[K, V, C](
         if (outputFile != null) {
           outputFile.delete()
         }
+        if (indexFile != null) {
+          indexFile.delete()
+        }
         return None
       }
     } finally {

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 68 additions & 2 deletions

@@ -25,7 +25,7 @@ import scala.collection.mutable
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark._
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.BlockId
 import org.apache.spark.executor.ShuffleWriteMetrics
@@ -171,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C](
     elementsPerPartition: Array[Long])
   private val spills = new ArrayBuffer[SpilledFile]
 
-  def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -645,6 +645,72 @@ private[spark] class ExternalSorter[K, V, C](
   */
  def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
 
+  /**
+   * Write all the data added into this ExternalSorter into a file in the disk store, creating
+   * an .index file for it as well with the offsets of each partition. This is called by the
+   * SortShuffleWriter and can go through an efficient path of just concatenating binary files
+   * if we decided to avoid merge-sorting.
+   *
+   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+   * @return array of lengths, in bytes, for each partition of the file (for map output tracker)
+   */
+  def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = {
+    val outputFile = blockManager.diskBlockManager.getFile(blockId)
+
+    // Track location of each range in the output file
+    val offsets = new Array[Long](numPartitions + 1)
+    val lengths = new Array[Long](numPartitions)
+
+    // Statistics
+    var totalBytes = 0L
+    var totalTime = 0L
+
+    for ((id, elements) <- this.partitionedIterator) {
+      if (elements.hasNext) {
+        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+        for (elem <- elements) {
+          writer.write(elem)
+        }
+        writer.commitAndClose()
+        val segment = writer.fileSegment()
+        offsets(id + 1) = segment.offset + segment.length
+        lengths(id) = segment.length
+        totalTime += writer.timeWriting()
+        totalBytes += segment.length
+      } else {
+        // The partition is empty; don't create a new writer to avoid writing headers, etc
+        offsets(id + 1) = offsets(id)
+      }
+    }
+
+    val shuffleMetrics = new ShuffleWriteMetrics
+    shuffleMetrics.shuffleBytesWritten = totalBytes
+    shuffleMetrics.shuffleWriteTime = totalTime
+    context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
+    context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+    context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+
+    // Write an index file with the offsets of each block, plus a final offset at the end for the
+    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
+    // out where each block begins and ends.
+
+    val diskBlockManager = blockManager.diskBlockManager
+    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    try {
+      var i = 0
+      while (i < numPartitions + 1) {
+        out.writeLong(offsets(i))
+        i += 1
+      }
+    } finally {
+      out.close()
+    }
+
+    lengths
+  }
+
  def stop(): Unit = {
    spills.foreach(s => s.file.delete())
    spills.clear()

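The index file written by writePartitionedFile is just numPartitions + 1 longs: offsets(i) marks where partition i begins in the data file and offsets(i + 1) where it ends, so a reader needs only two adjacent entries to locate any partition's byte range. A minimal, self-contained sketch of that lookup using plain java.io (illustrative only; this is not Spark's actual SortShuffleManager.getBlockLocation implementation):

    import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

    object IndexFileExample {
      // Write an index in the same format as writePartitionedFile above:
      // numPartitions + 1 longs, with offsets(0) = 0 and offsets(numPartitions) = data file length.
      def writeIndex(indexFile: File, offsets: Array[Long]): Unit = {
        val out = new DataOutputStream(new FileOutputStream(indexFile))
        try offsets.foreach(o => out.writeLong(o)) finally out.close()
      }

      // Hypothetical reader: skip to the entry for `partition`, read its start and end
      // offsets, and return (offset, length) of that partition's bytes in the data file.
      def readPartitionRange(indexFile: File, partition: Int): (Long, Long) = {
        val in = new DataInputStream(new FileInputStream(indexFile))
        try {
          in.skipBytes(partition * 8)   // each offset is an 8-byte long
          val start = in.readLong()
          val end = in.readLong()
          (start, end - start)
        } finally {
          in.close()
        }
      }

      def main(args: Array[String]): Unit = {
        val f = File.createTempFile("shuffle_0_0_0", ".index")
        writeIndex(f, Array(0L, 10L, 10L, 35L))        // 3 partitions of 10, 0 and 25 bytes
        assert(readPartitionRange(f, 1) == (10L, 0L))  // empty middle partition
        assert(readPartitionRange(f, 2) == (10L, 25L))
        f.delete()
      }
    }
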
core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala

Lines changed: 18 additions & 18 deletions

@@ -86,28 +86,28 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     // Both aggregator and ordering
     val sorter = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
-    sorter.write(elements.iterator)
+    sorter.insertAll(elements.iterator)
     assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), None, None)
-    sorter2.write(elements.iterator)
+    sorter2.insertAll(elements.iterator)
     assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
-    sorter3.write(elements.iterator)
+    sorter3.insertAll(elements.iterator)
     assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
-    sorter4.write(elements.iterator)
+    sorter4.insertAll(elements.iterator)
     assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter4.stop()
   }
@@ -124,7 +124,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
-    sorter.write(elements)
+    sorter.insertAll(elements)
     assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
     val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
     assert(iter.next() === (0, Nil))
@@ -287,13 +287,13 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
 
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     sorter.stop()
     assert(diskBlockManager.getAllBlocks().length === 0)
 
     val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter2.write((0 until 100000).iterator.map(i => (i, i)))
+    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
     sorter2.stop()
@@ -309,7 +309,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
 
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
     intercept[SparkException] {
-      sorter.write((0 until 100000).iterator.map(i => {
+      sorter.insertAll((0 until 100000).iterator.map(i => {
         if (i == 99990) {
           throw new SparkException("Intentional failure")
         }
@@ -365,7 +365,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
 
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
@@ -381,7 +381,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -397,7 +397,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -414,7 +414,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -431,7 +431,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100).iterator.map(i => (i, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
     val expected = (0 until 3).map(p => {
       (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
@@ -448,7 +448,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
     val expected = (0 until 3).map(p => {
       (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
@@ -495,7 +495,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
       collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
 
-    sorter.write(toInsert)
+    sorter.insertAll(toInsert)
 
     // A map of collision pairs in both directions
     val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
@@ -524,7 +524,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).
     val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
-    sorter.write(toInsert.iterator)
+    sorter.insertAll(toInsert.iterator)
 
     val it = sorter.iterator
     var count = 0
@@ -548,7 +548,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
     val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
 
-    sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+    sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
 
     val it = sorter.iterator
     while (it.hasNext) {
@@ -572,7 +572,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
       Some(agg), None, None, None)
 
-    sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+    sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
       (null.asInstanceOf[String], "1"),
       ("1", null.asInstanceOf[String]),
      (null.asInstanceOf[String], null.asInstanceOf[String])
