
Commit a34b352

Fix tracking of indices within a partition in SpillReader, and add test
1 parent 03e1006 commit a34b352

2 files changed, +59 -13 lines changed

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 26 additions & 13 deletions
@@ -480,11 +480,16 @@ private[spark] class ExternalSorter[K, V, C](
     val fileStream = new FileInputStream(spill.file)
     val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)

-    // Track which partition and which batch stream we're in
+    // Track which partition and which batch stream we're in. These will be the indices of
+    // the next element we will read. We'll also store the last partition read so that
+    // readNextPartition() can figure out what partition that was from.
     var partitionId = 0
-    var indexInPartition = -1L // Just to make sure we start at index 0
+    var indexInPartition = 0L
     var batchStreamsRead = 0
     var indexInBatch = 0
+    var lastPartitionId = 0
+
+    skipToNextPartition()

     // An intermediate stream that reads from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
@@ -500,6 +505,18 @@ private[spark] class ExternalSorter[K, V, C](
       ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
     }

+    /**
+     * Update partitionId if we have reached the end of our current partition, possibly skipping
+     * empty partitions on the way.
+     */
+    private def skipToNextPartition() {
+      while (partitionId < numPartitions &&
+          indexInPartition == spill.elementsPerPartition(partitionId)) {
+        partitionId += 1
+        indexInPartition = 0L
+      }
+    }
+
     /**
      * Return the next (K, C) pair from the deserialization stream and update partitionId,
      * indexInPartition, indexInBatch and such to match its location.
@@ -513,6 +530,7 @@ private[spark] class ExternalSorter[K, V, C](
       }
       val k = deserStream.readObject().asInstanceOf[K]
       val c = deserStream.readObject().asInstanceOf[C]
+      lastPartitionId = partitionId
       // Start reading the next batch if we're done with this one
       indexInBatch += 1
       if (indexInBatch == serializerBatchSize) {
@@ -521,16 +539,11 @@ private[spark] class ExternalSorter[K, V, C](
         deserStream = serInstance.deserializeStream(compressedStream)
         indexInBatch = 0
       }
-      // Update the partition location of the element we're reading, possibly skipping zero-length
-      // partitions until we get to the next non-empty one or to EOF.
+      // Update the partition location of the element we're reading
       indexInPartition += 1
-      while (indexInPartition == spill.elementsPerPartition(partitionId)) {
-        partitionId += 1
-        indexInPartition = 0
-      }
-      if (partitionId == numPartitions - 1 &&
-          indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
-        // This is the last element, remember that we're done
+      skipToNextPartition()
+      // If we've finished reading the last partition, remember that we're done
+      if (partitionId == numPartitions) {
         finished = true
         deserStream.close()
       }
@@ -550,10 +563,10 @@ private[spark] class ExternalSorter[K, V, C](
             return false
           }
         }
-        assert(partitionId >= myPartition)
+        assert(lastPartitionId >= myPartition)
         // Check that we're still in the right partition; note that readNextItem will have returned
         // null at EOF above so we would've returned false there
-        partitionId == myPartition
+        lastPartitionId == myPartition
       }

       override def next(): Product2[K, C] = {
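
For readers following the change outside the diff, below is a minimal, self-contained sketch of the corrected bookkeeping. It is an illustration only, not the real SpillReader: elementsPerPartition and data stand in for the spill metadata and the serialized elements, and a simple loop stands in for the deserialization stream. The three parts of the fix are the same as in the hunks above: partitionId and indexInPartition always point at the next element to be read, skipToNextPartition() runs once before the first read and again after every read so empty partitions (including runs of them at the start or end of the spill) are stepped over, and lastPartitionId remembers which partition the element just read belonged to so readNextPartition() can compare against it.

// Minimal model of the corrected index tracking (illustration only; field names mirror
// the real ones, but this is not the actual SpillReader).
object IndexTrackingSketch {
  def main(args: Array[String]): Unit = {
    val elementsPerPartition = Array(0L, 1L, 3L, 0L)   // partitions 0 and 3 are empty
    val data = Seq("a", "b", "c", "d")                 // all elements, in partition order
    val numPartitions = elementsPerPartition.length

    // Indices of the NEXT element to be read, as in the fix.
    var partitionId = 0
    var indexInPartition = 0L
    var lastPartitionId = 0    // partition of the element most recently read
    var pos = 0                // position in `data`; stands in for the deserialization stream

    // Advance partitionId past partitions that have no elements left to read.
    def skipToNextPartition(): Unit = {
      while (partitionId < numPartitions &&
          indexInPartition == elementsPerPartition(partitionId)) {
        partitionId += 1
        indexInPartition = 0L
      }
    }

    skipToNextPartition()      // handles partitions that are empty at the very start
    while (partitionId < numPartitions) {
      val elem = data(pos)     // "read" the next element
      pos += 1
      lastPartitionId = partitionId
      indexInPartition += 1
      skipToNextPartition()    // possibly step over empty partitions after the read
      println(s"read $elem from partition $lastPartitionId")
    }
    // Prints: "a" from partition 1, then "b", "c", "d" from partition 2.
  }
}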

core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala

Lines changed: 33 additions & 0 deletions
@@ -38,21 +38,25 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val sorter = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
     assert(sorter.iterator.toSeq === Seq())
+    sorter.stop()

     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(3)), None, None)
     assert(sorter2.iterator.toSeq === Seq())
+    sorter2.stop()

     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
     assert(sorter3.iterator.toSeq === Seq())
+    sorter3.stop()

     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), None, None)
     assert(sorter4.iterator.toSeq === Seq())
+    sorter4.stop()
   }

   test("few elements per partition") {
@@ -73,24 +77,53 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
       Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
     sorter.write(elements.iterator)
     assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter.stop()

     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), None, None)
     sorter2.write(elements.iterator)
     assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter2.stop()

     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
     sorter3.write(elements.iterator)
     assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter3.stop()

     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
     sorter4.write(elements.iterator)
     assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter4.stop()
+  }
+
+  test("empty partitions with spilling") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 50000).iterator.map(x => (2, 2))
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), None, None)
+    sorter.write(elements)
+    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
+    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+    assert(iter.next() === (0, Nil))
+    assert(iter.next() === (1, List((1, 1))))
+    assert(iter.next() === (2, (0 until 50000).map(x => (2, 2)).toList))
+    assert(iter.next() === (3, Nil))
+    assert(iter.next() === (4, Nil))
+    assert(iter.next() === (5, List((5, 5))))
+    assert(iter.next() === (6, Nil))
+    sorter.stop()
   }

   test("spilling in local cluster") {
