@@ -17,12 +17,15 @@
 
 package org.apache.spark.util.collection
 
-import org.apache.spark.{SparkEnv, Aggregator, Logging, Partitioner}
-import org.apache.spark.serializer.Serializer
+import java.io._
 
 import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.BlockId
-import java.io.File
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -49,6 +52,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer.getOrElse(null))
+  private val serInstance = ser.newInstance()
 
   private val conf = SparkEnv.get.conf
   private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
@@ -232,7 +236,7 @@ private[spark] class ExternalSorter[K, V, C](
     // TODO: merge intermediate results if they are sorted by the comparator
     val readers = spills.map(new SpillReader(_))
     (0 until numPartitions).iterator.map { p =>
-      (p, readers.iterator.flatMap(_.readPartition(p)))
+      (p, readers.iterator.flatMap(_.readNextPartition()))
     }
   }
 
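The interesting part of this hunk is the signature change: a spill file can only be scanned forward, so readNextPartition() takes no partition argument and instead hands back partitions in strictly increasing order, relying on the surrounding loop to request them that way. A minimal sketch of that contract, with a made-up ToyReader standing in for a real spill file:

object PartitionOrderDemo {
  // A stand-in for SpillReader: data is (partitionId, value) pairs already sorted by partition.
  class ToyReader(data: Seq[(Int, String)]) {
    private val iter = data.iterator.buffered
    private var nextPartitionToRead = 0

    def readNextPartition(): Iterator[String] = {
      val myPartition = nextPartitionToRead
      nextPartitionToRead += 1
      new Iterator[String] {
        def hasNext: Boolean = iter.hasNext && iter.head._1 == myPartition
        def next(): String = iter.next()._2
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    val readers = Seq(
      new ToyReader(Seq(0 -> "a", 2 -> "c")),
      new ToyReader(Seq(1 -> "b", 2 -> "d")))
    // Same shape as the merge loop above: concatenate every reader's slice of partition p.
    (0 until numPartitions).iterator.map { p =>
      (p, readers.iterator.flatMap(_.readNextPartition()))
    }.foreach { case (p, it) =>
      println(s"partition $p: ${it.mkString(", ")}")  // 0: a | 1: b | 2: c, d
    }
  }
}

The grouping stays correct only because the (p, iterator) pairs are produced lazily and each inner iterator is exhausted before the next partition is requested.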
@@ -241,7 +245,92 @@ private[spark] class ExternalSorter[K, V, C](
    * partitions to be requested in order.
    */
   private class SpillReader(spill: SpilledFile) {
-    def readPartition(id: Int): Iterator[Product2[K, C]] = ???
+    val fileStream = new FileInputStream(spill.file)
+    val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+
+    // An intermediate stream that reads from exactly one batch.
+    // This guards against pre-fetching and other arbitrary behavior of higher-level streams.
+    var batchStream = nextBatchStream()
+    var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+    var deserStream = serInstance.deserializeStream(compressedStream)
+    var nextItem: (K, C) = null
+    var finished = false
+
+    // Track which partition and which batch stream we're in
+    var partitionId = 0
+    var indexInPartition = -1L  // Just to make sure we start at index 0
+    var batchStreamsRead = 0
+    var indexInBatch = -1
+    /** Construct a stream that only reads from the next batch */
+    def nextBatchStream(): InputStream = {
+      if (batchStreamsRead < spill.serializerBatchSizes.length) {
+        batchStreamsRead += 1
+        ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+      } else {
+        // No more batches left
+        bufferedStream
+      }
+    }
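// Sketch of what ByteStreams.limit buys us here (illustration; the values are made up):
// Guava's ByteStreams.limit returns a view of the underlying stream capped at N bytes, so a
// deserializer built over one batch sees a clean EOF at the batch boundary instead of
// pre-fetching into the next batch. For example:
//
//   val in = new java.io.ByteArrayInputStream(Array[Byte](1, 2, 3, 4, 5))
//   val firstBatch = ByteStreams.limit(in, 3)
//   firstBatch.read(new Array[Byte](10))  // returns 3: the view is exhausted after 3 bytes
//   in.read()                             // returns 4: the base stream resumes at byte 4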
+
+    /**
+     * Return the next (K, C) pair from the deserialization stream and update partitionId,
+     * indexInPartition, indexInBatch and such to match its location.
+     *
+     * If the current batch is drained, construct a stream for the next batch and read from it.
+     * If no more pairs are left, return null.
+     */
+    private def readNextItem(): (K, C) = {
+      try {
+        if (finished) {
+          return null
+        }
+        // Start reading the next batch if we're done with this one
+        indexInBatch += 1
+        if (indexInBatch == serializerBatchSize) {
+          batchStream = nextBatchStream()
+          compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+          deserStream = serInstance.deserializeStream(compressedStream)
+          indexInBatch = 0
+        }
+        // Update the partition location of the element we're reading; partitionId is allowed
+        // to reach numPartitions once the file is drained (hasNext uses this as the EOF signal)
+        indexInPartition += 1
+        while (partitionId < numPartitions &&
+            indexInPartition == spill.elementsPerPartition(partitionId)) {
+          partitionId += 1
+          indexInPartition = 0
+        }
+        val k = deserStream.readObject().asInstanceOf[K]
+        val c = deserStream.readObject().asInstanceOf[C]
+        (k, c)
+      } catch {
+        case e: EOFException =>
+          finished = true
+          deserStream.close()
+          null
+      }
+    }
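// A sketch of the cursor arithmetic above, with made-up counts: successive reads advance
// (partitionId, indexInPartition), skipping empty partitions, and the cursor parks at
// numPartitions once the file is drained, which is exactly what hasNext below tests for.
//
//   val counts = Array(2L, 0L, 1L)   // stand-in for spill.elementsPerPartition
//   var pid = 0; var idx = -1L
//   def advance(): Unit = {
//     idx += 1
//     while (pid < counts.length && idx == counts(pid)) { pid += 1; idx = 0 }
//   }
//   advance()  // (pid, idx) = (0, 0): first element of partition 0
//   advance()  // (0, 1): second element of partition 0
//   advance()  // (2, 0): partition 1 is empty, skip straight to partition 2
//   advance()  // (3, 0): past the last partition, i.e. "numPartitions at EOF"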
+
+    var nextPartitionToRead = 0
+
+    def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
+      val myPartition = nextPartitionToRead
+      nextPartitionToRead += 1
+
+      override def hasNext: Boolean = {
+        if (nextItem == null) {
+          nextItem = readNextItem()
+        }
+        // Check that we're still in the right partition; will be numPartitions at EOF
+        partitionId == myPartition
+      }
+
+      override def next(): Product2[K, C] = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        // Clear the cached item so the next hasNext() call advances the stream
+        val item = nextItem
+        nextItem = null
+        item
+      }
+    }
   }
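None of this works unless the spill file was written to match: elements sorted by partition, serialized in independent batches whose byte sizes land in serializerBatchSizes and whose per-partition counts land in elementsPerPartition. The actual spill-writing code lives elsewhere in this file; below is a rough, self-contained sketch of the layout the reader assumes, using plain Java serialization and no compression for brevity:

import java.io._
import scala.collection.mutable.ArrayBuffer

object SpillWriteSketch {
  // Writes (partition, (key, combiner)) elements, pre-sorted by partition, in batches of
  // batchSize. Returns (bytes per batch, elements per partition) -- the two arrays the
  // reader's bookkeeping depends on.
  def writeBatchedSpill(
      file: File,
      elements: Seq[(Int, (String, Int))],
      batchSize: Int,
      numPartitions: Int): (Array[Long], Array[Long]) = {
    val batchSizes = new ArrayBuffer[Long]
    val perPartition = new Array[Long](numPartitions)
    val out = new FileOutputStream(file)
    try {
      elements.grouped(batchSize).foreach { batch =>
        // Each batch is serialized independently, so a reader can wrap exactly one batch
        // in its own deserialization stream (cf. nextBatchStream above).
        val buf = new ByteArrayOutputStream()
        val objOut = new ObjectOutputStream(buf)
        batch.foreach { case (p, (k, c)) =>
          perPartition(p) += 1
          objOut.writeObject(k)
          objOut.writeObject(c)
        }
        objOut.close()
        batchSizes += buf.size().toLong
        buf.writeTo(out)
      }
    } finally {
      out.close()
    }
    (batchSizes.toArray, perPartition)
  }
}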
 
   /**