Commit 3a56341

More partial work towards sort-based shuffle
1 parent 7a0895d · commit 3a56341

File tree: 3 files changed, +123 -24 lines

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 21 additions & 12 deletions
@@ -21,9 +21,10 @@ import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
 import org.apache.spark.{SparkEnv, Logging, TaskContext}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.collection.ExternalSorter
 
-private[spark] class SortShuffleWriter[K, V](
-    handle: BaseShuffleHandle[K, V, _],
+private[spark] class SortShuffleWriter[K, V, C](
+    handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
     context: TaskContext)
   extends ShuffleWriter[K, V] with Logging {
@@ -38,19 +39,27 @@ private[spark] class SortShuffleWriter[K, V](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    val iter = if (dep.aggregator.isDefined) {
+    var sorter: ExternalSorter[K, V, _] = null
+
+    val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
       if (dep.mapSideCombine) {
-        // TODO: This does an external merge-sort if the data is highly combinable, and then we
-        // do another one later to sort them by output partition. We can improve this by doing
-        // the merging as part of the SortedFileWriter.
-        dep.aggregator.get.combineValuesByKey(records, context)
+        if (!dep.aggregator.isDefined) {
+          throw new IllegalStateException("Aggregator is empty for map-side combine")
+        }
+        sorter = new ExternalSorter[K, V, C](
+          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+        sorter.write(records)
+        sorter.partitionedIterator
       } else {
-        records
+        sorter = new ExternalSorter[K, V, V](
+          None, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+        sorter.write(records)
+        sorter.partitionedIterator
       }
-    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
-      throw new IllegalStateException("Aggregator is empty for map-side combine")
-    } else {
-      records
+    }
+
+    for ((id, elements) <- partitions) {
+
     }
 
     ???
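
In the new write() path, the map-side-combine branch routes records through an ExternalSorter[K, V, C] built with the dependency's aggregator, while the non-combining branch uses ExternalSorter[K, V, V] with no aggregator; both branches end in sorter.partitionedIterator, and the for ((id, elements) <- partitions) loop that will consume it is still empty in this commit. The snippet below is not part of the commit and not Spark's implementation; it is a self-contained toy sketch of the (partitionId, elements) shape that partitionedIterator hands back, with a simple hash-modulo bucket standing in for dep.partitioner and all names (PartitionedIteratorSketch, buckets) purely illustrative.

object PartitionedIteratorSketch {
  def partitionedIterator[K, V](
      records: Iterator[Product2[K, V]],
      numPartitions: Int): Iterator[(Int, Iterator[Product2[K, V]])] = {
    // Toy in-memory version: bucket every record by partition, then emit buckets in id order
    val buckets = Array.fill(numPartitions)(scala.collection.mutable.ArrayBuffer[Product2[K, V]]())
    for (kv <- records) {
      val partition = math.abs(kv._1.hashCode % numPartitions)  // stand-in for dep.partitioner
      buckets(partition) += kv
    }
    (0 until numPartitions).iterator.map(p => (p, buckets(p).iterator))
  }

  def main(args: Array[String]): Unit = {
    val records = Iterator(("a", 1), ("b", 2), ("c", 3), ("a", 4))
    // This is the shape the still-empty `for ((id, elements) <- partitions)` loop receives
    for ((id, elements) <- partitionedIterator[String, Int](records, 2)) {
      println(s"partition $id: ${elements.mkString(", ")}")
    }
  }
}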

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 94 additions & 5 deletions
@@ -17,12 +17,15 @@
 
 package org.apache.spark.util.collection
 
-import org.apache.spark.{SparkEnv, Aggregator, Logging, Partitioner}
-import org.apache.spark.serializer.Serializer
+import java.io._
 
 import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.BlockId
-import java.io.File
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -49,6 +52,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer.getOrElse(null))
+  private val serInstance = ser.newInstance()
 
   private val conf = SparkEnv.get.conf
   private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
@@ -232,7 +236,7 @@ private[spark] class ExternalSorter[K, V, C](
     // TODO: merge intermediate results if they are sorted by the comparator
     val readers = spills.map(new SpillReader(_))
     (0 until numPartitions).iterator.map { p =>
-      (p, readers.iterator.flatMap(_.readPartition(p)))
+      (p, readers.iterator.flatMap(_.readNextPartition()))
     }
   }
 
@@ -241,7 +245,92 @@
   * partitions to be requested in order.
   */
  private class SpillReader(spill: SpilledFile) {
-    def readPartition(id: Int): Iterator[Product2[K, C]] = ???
+    val fileStream = new FileInputStream(spill.file)
+    val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+
+    // An intermediate stream that reads from exactly one batch
+    // This guards against pre-fetching and other arbitrary behavior of higher level streams
+    var batchStream = nextBatchStream()
+    var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+    var deserStream = serInstance.deserializeStream(compressedStream)
+    var nextItem: (K, C) = null
+    var finished = false
+
+    // Track which partition and which batch stream we're in
+    var partitionId = 0
+    var indexInPartition = -1L  // Just to make sure we start at index 0
+    var batchStreamsRead = 0
+    var indexInBatch = -1
+
+    /** Construct a stream that only reads from the next batch */
+    def nextBatchStream(): InputStream = {
+      if (batchStreamsRead < spill.serializerBatchSizes.length) {
+        batchStreamsRead += 1
+        ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+      } else {
+        // No more batches left
+        bufferedStream
+      }
+    }
+
+    /**
+     * Return the next (K, C) pair from the deserialization stream and update partitionId,
+     * indexInPartition, indexInBatch and such to match its location.
+     *
+     * If the current batch is drained, construct a stream for the next batch and read from it.
+     * If no more pairs are left, return null.
+     */
+    private def readNextItem(): (K, C) = {
+      try {
+        if (finished) {
+          return null
+        }
+        // Start reading the next batch if we're done with this one
+        indexInBatch += 1
+        if (indexInBatch == serializerBatchSize) {
+          batchStream = nextBatchStream()
+          compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+          deserStream = serInstance.deserializeStream(compressedStream)
+          indexInBatch = 0
+        }
+        // Update the partition location of the element we're reading
+        indexInPartition += 1
+        while (indexInPartition == spill.elementsPerPartition(partitionId)) {
+          partitionId += 1
+          indexInPartition = 0
+        }
+        val k = deserStream.readObject().asInstanceOf[K]
+        val c = deserStream.readObject().asInstanceOf[C]
+        (k, c)
+      } catch {
+        case e: EOFException =>
+          finished = true
+          deserStream.close()
+          null
+      }
+    }
+
+    var nextPartitionToRead = 0
+
+    def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
+      val myPartition = nextPartitionToRead
+      nextPartitionToRead += 1
+
+      override def hasNext: Boolean = {
+        if (nextItem == null) {
+          nextItem = readNextItem()
+        }
+        // Check that we're still in the right partition; will be numPartitions at EOF
+        partitionId == myPartition
+      }

+      override def next(): Product2[K, C] = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        nextItem
+      }
+    }
   }
 
  /**
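
A note on the SpillReader added above: it assumes each spill file was written as a sequence of serialization batches whose byte lengths were recorded in spill.serializerBatchSizes, so the reader can cap each batch with ByteStreams.limit() and the deserializer can never pre-fetch past its own batch. The standalone sketch below is not Spark code; plain Java serialization and Guava's CountingOutputStream stand in for Spark's Serializer and the writer-side byte accounting, and every name in it (BatchedSpillSketch, batchBytes, and so on) is illustrative only. It shows the write-then-read pattern in isolation.

import java.io._

import scala.collection.mutable.ArrayBuffer

import com.google.common.io.{ByteStreams, CountingOutputStream}

object BatchedSpillSketch {
  def main(args: Array[String]): Unit = {
    val records = (1 to 10).map(i => (s"key$i", i))
    val batchSize = 3                       // elements per batch, like serializerBatchSize
    val batchBytes = new ArrayBuffer[Long]  // bytes per batch, like spill.serializerBatchSizes

    // Write each batch through its own serialization stream and record how many bytes it took
    val file = File.createTempFile("spill-sketch", ".bin")
    val countingOut = new CountingOutputStream(new FileOutputStream(file))
    var bytesSoFar = 0L
    for (batch <- records.grouped(batchSize)) {
      val objOut = new ObjectOutputStream(countingOut)
      batch.foreach(objOut.writeObject)
      objOut.flush()                        // flush but don't close: closing would close the file
      batchBytes += countingOut.getCount - bytesSoFar
      bytesSoFar = countingOut.getCount
    }
    countingOut.close()

    // Read the batches back: ByteStreams.limit() caps each deserialization stream at exactly
    // the recorded batch size, so it can never read into the next batch
    val fileIn = new BufferedInputStream(new FileInputStream(file))
    for ((size, batchIndex) <- batchBytes.zipWithIndex) {
      val elementsInBatch = math.min(batchSize, records.size - batchIndex * batchSize)
      val objIn = new ObjectInputStream(ByteStreams.limit(fileIn, size))
      for (_ <- 1 to elementsInBatch) {
        println(objIn.readObject())
      }
      // Don't close objIn here either, since it shares the underlying file stream
    }
    fileIn.close()
    file.delete()
  }
}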

core/src/main/scala/org/apache/spark/util/collection/SizeTrackingBuffer.scala

Lines changed: 8 additions & 7 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.util.collection
 import java.util.Arrays
 import java.util.Comparator
 
-import scala.reflect.ClassTag
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.util.SizeEstimator
@@ -32,13 +31,15 @@ import org.apache.spark.util.SizeEstimator
  *
  * The tracking code is copied from SizeTrackingAppendOnlyMap -- we'll factor that out soon.
  */
-private[spark] class SizeTrackingBuffer[T: ClassTag](initialCapacity: Int = 64)
+private[spark] class SizeTrackingBuffer[T <: AnyRef](initialCapacity: Int = 64)
   extends SizeTrackingCollection[T]
 {
-  // Basic growable array data structure
+  // Basic growable array data structure. NOTE: We use an Array of AnyRef because Arrays.sort()
+  // is not easy to call on an Array[T], and Scala doesn't provide a great way to sort a generic
+  // array using a Comparator.
   private var capacity = initialCapacity
   private var curSize = 0
-  private var data = new Array[T](initialCapacity)
+  private var data = new Array[AnyRef](initialCapacity)
 
   // Size-tracking variables: we maintain a sequence of samples since the size of the collection
   // depends on both the array and how many of its elements are filled. We reset this each time
@@ -91,7 +92,7 @@ private[spark] class SizeTrackingBuffer[T: ClassTag](initialCapacity: Int = 64)
     override def next(): T = {
       val elem = data(pos)
       pos += 1
-      elem
+      elem.asInstanceOf[T]
     }
   }
 
@@ -112,7 +113,7 @@ private[spark] class SizeTrackingBuffer[T: ClassTag](initialCapacity: Int = 64)
       throw new Exception("Can't grow buffer beyond 2^30 elements")
     }
     val newCapacity = capacity * 2
-    val newArray = new Array[T](newCapacity)
+    val newArray = new Array[AnyRef](newCapacity)
    System.arraycopy(data, 0, newArray, 0, capacity)
    data = newArray
    capacity = newCapacity
@@ -143,7 +144,7 @@ private[spark] class SizeTrackingBuffer[T: ClassTag](initialCapacity: Int = 64)
 
   /** Iterate through the data in a given order. For this class this is not really destructive. */
   override def destructiveSortedIterator(cmp: Comparator[T]): Iterator[T] = {
-    Arrays.sort(data, 0, curSize, cmp)
+    Arrays.sort(data, 0, curSize, cmp.asInstanceOf[Comparator[AnyRef]])
     iterator
   }
 }
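
The NOTE added in this diff is the main design point of the file: java.util.Arrays.sort has no overload that sorts a slice of a generic Array[T] with a Comparator, so the buffer stores AnyRef and casts elements back to T on the way out. The tiny standalone snippet below is not Spark code (names are illustrative) and only demonstrates that store-as-AnyRef, sort-a-prefix, cast-on-read pattern.

import java.util.{Arrays, Comparator}

object AnyRefSortSketch {
  def main(args: Array[String]): Unit = {
    val data = new Array[AnyRef](8)          // over-allocated, like the buffer's backing array
    val pairs = Seq(("b", 2), ("a", 1), ("c", 3))
    pairs.zipWithIndex.foreach { case (p, i) => data(i) = p }

    // Sort only the filled prefix [0, curSize), comparing by key
    val cmp = new Comparator[AnyRef] {
      def compare(x: AnyRef, y: AnyRef): Int =
        x.asInstanceOf[(String, Int)]._1.compareTo(y.asInstanceOf[(String, Int)]._1)
    }
    Arrays.sort(data, 0, pairs.length, cmp)

    // Elements come back out with a cast, as in the buffer's iterator
    val sorted = data.take(pairs.length).map(_.asInstanceOf[(String, Int)])
    println(sorted.mkString(", "))  // (a,1), (b,2), (c,3)
  }
}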
