Commit 8e7d5ba
SPARK-2792. Fix reading too much or too little data from each stream in ExternalMap / Sorter
All these changes are from mridulm's work in apache#1609, but extracted here to fix this
specific issue and make it easier to merge into 1.1. This particular set of changes is to make
sure that we read exactly the right range of bytes from each spill file in EAOM: some
serializers can write bytes after the last object (e.g. the TC_RESET flag in Java serialization)
and that would confuse the previous code into reading it as part of the next batch. There are
also improvements to cleanup to make sure files are closed.

In addition to bringing in the changes to ExternalAppendOnlyMap, I also copied them to the
corresponding code in ExternalSorter and updated its test suite to test for the same issues.

Author: Matei Zaharia <matei@databricks.com>

Closes apache#1722 from mateiz/spark-2792 and squashes the following commits:

5d4bfb5 [Matei Zaharia] Make objectStreamReset counter count the last object written too
18fe865 [Matei Zaharia] Update docs on objectStreamReset
576ee83 [Matei Zaharia] Allow objectStreamReset to be 0
0374217 [Matei Zaharia] Remove super paranoid code to close file handles
bda37bb [Matei Zaharia] Implement Mridul's ExternalAppendOnlyMap fixes in ExternalSorter too
0d6dad7 [Matei Zaharia] Added Mridul's test changes for ExternalAppendOnlyMap
9a78e4b [Matei Zaharia] Add @mridulm's fixes to ExternalAppendOnlyMap for batch sizes
1 parent 59f84a9 commit 8e7d5ba

File tree: 6 files changed, +194 −83 lines
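To make the TC_RESET issue described in the commit message concrete: Java serialization's ObjectOutputStream.reset() writes a marker byte into the stream right after the last object of the batch being flushed, so a reader that trusts object counts alone can stray into the next batch. Below is a minimal sketch using plain JDK streams (not Spark code; the object type and counts are illustrative) showing that resets add bytes for the same payload — exactly the bytes the batch-offset bookkeeping in the diffs below accounts for.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object ResetOverheadDemo {
  def main(args: Array[String]): Unit = {
    def sizeOf(resetEachObject: Boolean): Int = {
      val bytes = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(bytes)
      for (i <- 1 to 100) {
        out.writeObject(Integer.valueOf(i))
        if (resetEachObject) {
          out.reset()  // immediately writes a TC_RESET marker byte into the stream
        }
      }
      out.close()
      bytes.size()
    }
    // The reset-enabled stream holds the same 100 objects but is larger; the extra
    // bytes land after the last object written before each reset.
    println(s"without reset: ${sizeOf(false)} bytes, with reset: ${sizeOf(true)} bytes")
  }
}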

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 2 additions & 3 deletions
@@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
   /**
    * Calling reset to avoid memory leak:
    * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
-   * But only call it every 10,000th time to avoid bloated serialization streams (when
+   * But only call it every 100th time to avoid bloated serialization streams (when
    * the stream 'resets' object class descriptions have to be re-written)
    */
   def writeObject[T: ClassTag](t: T): SerializationStream = {
     objOut.writeObject(t)
+    counter += 1
     if (counterReset > 0 && counter >= counterReset) {
       objOut.reset()
       counter = 0
-    } else {
-      counter += 1
     }
     this
   }
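The reset frequency is controlled by spark.serializer.objectStreamReset, whose docs this PR updates; with the counter change above the value counts the object just written, and, per the "Allow objectStreamReset to be 0" commit and the counterReset > 0 guard, a value of 0 disables resets. A hedged configuration sketch; the application code around the setting is illustrative only:

import org.apache.spark.{SparkConf, SparkContext}

object ObjectStreamResetExample {
  def main(args: Array[String]): Unit = {
    // Reset the Java serialization stream after every 100 written objects; "0" disables resets.
    val conf = new SparkConf()
      .setAppName("objectStreamReset-example")   // illustrative app name
      .setMaster("local[2]")                     // illustrative master for a local run
      .set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
      .set("spark.serializer.objectStreamReset", "100")

    val sc = new SparkContext(conf)
    try {
      // ... application logic ...
    } finally {
      sc.stop()
    }
  }
}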

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 65 additions & 21 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.collection
 
-import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
+import java.io._
 import java.util.Comparator
 
 import scala.collection.BufferedIterator
@@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 
@@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
       objectsWritten = 0
     }
 
+    var success = false
     try {
       val it = currentMap.destructiveSortedIterator(keyComparator)
       while (it.hasNext) {
@@ -215,16 +218,28 @@ class ExternalAppendOnlyMap[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer.close()
           writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
       if (objectsWritten > 0) {
         flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
       }
+      success = true
     } finally {
-      // Partial failures cannot be tolerated; do not revert partial writes
-      writer.close()
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
+      }
     }
 
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
@@ -389,27 +404,51 @@ class ExternalAppendOnlyMap[K, V, C](
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
   private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
-    extends Iterator[(K, C)] {
-    private val fileStream = new FileInputStream(file)
-    private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    extends Iterator[(K, C)]
+  {
+    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Size will be batchSize.length + 1
+    assert(file.length() == batchOffsets(batchOffsets.length - 1))
+
+    private var batchIndex = 0  // Which batch we're in
+    private var fileStream: FileInputStream = null
 
     // An intermediate stream that reads from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    private var batchStream = nextBatchStream()
-    private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-    private var deserializeStream = ser.deserializeStream(compressedStream)
+    private var deserializeStream = nextBatchStream()
     private var nextItem: (K, C) = null
     private var objectsRead = 0
 
     /**
      * Construct a stream that reads only from the next batch.
      */
-    private def nextBatchStream(): InputStream = {
-      if (batchSizes.length > 0) {
-        ByteStreams.limit(bufferedStream, batchSizes.remove(0))
+    private def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchIndex < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchIndex)
+        fileStream = new FileInputStream(file)
+        fileStream.getChannel.position(start)
+        batchIndex += 1
+
+        val end = batchOffsets(batchIndex)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+        ser.deserializeStream(compressedStream)
       } else {
         // No more batches left
-        bufferedStream
+        cleanup()
+        null
       }
     }
 
@@ -424,10 +463,8 @@ class ExternalAppendOnlyMap[K, V, C](
         val item = deserializeStream.readObject().asInstanceOf[(K, C)]
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
-          batchStream = nextBatchStream()
-          compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-          deserializeStream = ser.deserializeStream(compressedStream)
           objectsRead = 0
+          deserializeStream = nextBatchStream()
        }
         item
       } catch {
@@ -439,6 +476,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
     override def hasNext: Boolean = {
       if (nextItem == null) {
+        if (deserializeStream == null) {
+          return false
+        }
         nextItem = readNextItem()
       }
       nextItem != null
@@ -455,7 +495,11 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // TODO: Ensure this gets called even if the iterator isn't drained.
     private def cleanup() {
-      deserializeStream.close()
+      batchIndex = batchOffsets.length  // Prevent reading any other batch
+      val ds = deserializeStream
+      deserializeStream = null
+      fileStream = null
+      ds.close()
       file.delete()
     }
   }
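The heart of the DiskMapIterator change above is reading each spill batch through a byte-limited stream positioned at the batch's exact start offset. A standalone sketch of that idea, using the same scanLeft and Guava ByteStreams.limit calls as the patch (the helper names, and the absence of compression and deserialization wrapping, are simplifications for illustration):

import java.io.{BufferedInputStream, File, FileInputStream}

import com.google.common.io.ByteStreams

object BatchOffsetsSketch {
  // scanLeft over the per-batch byte sizes gives batchSizes.length + 1 absolute offsets:
  // [0, s0, s0 + s1, ...], so batch i occupies offsets(i) until offsets(i + 1).
  def batchOffsets(batchSizes: Seq[Long]): IndexedSeq[Long] =
    batchSizes.scanLeft(0L)(_ + _).toIndexedSeq

  // Open a stream that can read batch i and nothing else: seek the file channel to the
  // batch's start offset and cap the stream at the batch's length, so trailing serializer
  // bytes (e.g. a TC_RESET marker) or read-ahead can never bleed into the next batch.
  def openBatch(file: File, offsets: IndexedSeq[Long], i: Int): BufferedInputStream = {
    val start = offsets(i)
    val end = offsets(i + 1)
    assert(end >= start, s"start = $start, end = $end")
    val fileStream = new FileInputStream(file)
    fileStream.getChannel.position(start)
    new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
  }
}

In this sketch the caller is responsible for closing the returned stream, which in turn closes the underlying FileInputStream.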

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 75 additions & 29 deletions
@@ -26,7 +26,7 @@ import scala.collection.mutable
 import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.BlockId
 
 /**
@@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C](
     // Flush the disk writer's contents to disk, and update relevant variables.
     // The writer is closed at the end of this process, and cannot be reused.
     def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
       objectsWritten = 0
     }
 
+    var success = false
     try {
       val it = collection.destructiveSortedIterator(partitionKeyComparator)
       while (it.hasNext) {
@@ -299,13 +302,23 @@ private[spark] class ExternalSorter[K, V, C](
       }
       if (objectsWritten > 0) {
         flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
+      }
+      success = true
+    } finally {
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
       }
-      writer.close()
-    } catch {
-      case e: Exception =>
-        writer.close()
-        file.delete()
-        throw e
     }
 
     if (usingMap) {
@@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C](
    * partitions to be requested in order.
    */
   private[this] class SpillReader(spill: SpilledFile) {
-    val fileStream = new FileInputStream(spill.file)
-    val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    // Serializer batch offsets; size will be batchSize.length + 1
+    val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)
 
     // Track which partition and which batch stream we're in. These will be the indices of
     // the next element we will read. We'll also store the last partition read so that
     // readNextPartition() can figure out what partition that was from.
     var partitionId = 0
     var indexInPartition = 0L
-    var batchStreamsRead = 0
+    var batchId = 0
     var indexInBatch = 0
     var lastPartitionId = 0
 
     skipToNextPartition()
 
-    // An intermediate stream that reads from exactly one batch
+
+    // Intermediate file and deserializer streams that read from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    var batchStream = nextBatchStream()
-    var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
-    var deserStream = serInstance.deserializeStream(compressedStream)
+    var fileStream: FileInputStream = null
+    var deserializeStream = nextBatchStream()  // Also sets fileStream
+
     var nextItem: (K, C) = null
     var finished = false
 
     /** Construct a stream that only reads from the next batch */
-    def nextBatchStream(): InputStream = {
-      if (batchStreamsRead < spill.serializerBatchSizes.length) {
-        batchStreamsRead += 1
-        ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+    def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchId < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchId)
+        fileStream = new FileInputStream(spill.file)
+        fileStream.getChannel.position(start)
+        batchId += 1
+
+        val end = batchOffsets(batchId)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
+        serInstance.deserializeStream(compressedStream)
       } else {
-        // No more batches left; give an empty stream
-        bufferedStream
+        // No more batches left
+        cleanup()
+        null
      }
    }
 
@@ -525,27 +560,27 @@ private[spark] class ExternalSorter[K, V, C](
     * If no more pairs are left, return null.
     */
    private def readNextItem(): (K, C) = {
-      if (finished) {
+      if (finished || deserializeStream == null) {
        return null
      }
-      val k = deserStream.readObject().asInstanceOf[K]
-      val c = deserStream.readObject().asInstanceOf[C]
+      val k = deserializeStream.readObject().asInstanceOf[K]
+      val c = deserializeStream.readObject().asInstanceOf[C]
      lastPartitionId = partitionId
      // Start reading the next batch if we're done with this one
      indexInBatch += 1
      if (indexInBatch == serializerBatchSize) {
-        batchStream = nextBatchStream()
-        compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
-        deserStream = serInstance.deserializeStream(compressedStream)
        indexInBatch = 0
+        deserializeStream = nextBatchStream()
      }
      // Update the partition location of the element we're reading
      indexInPartition += 1
      skipToNextPartition()
      // If we've finished reading the last partition, remember that we're done
      if (partitionId == numPartitions) {
        finished = true
-        deserStream.close()
+        if (deserializeStream != null) {
+          deserializeStream.close()
+        }
      }
      (k, c)
    }
@@ -578,6 +613,17 @@ private[spark] class ExternalSorter[K, V, C](
         item
       }
     }
+
+    // Clean up our open streams and put us in a state where we can't read any more data
+    def cleanup() {
+      batchId = batchOffsets.length  // Prevent reading any other batch
+      val ds = deserializeStream
+      deserializeStream = null
+      fileStream = null
+      ds.close()
+      // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
+      // This should also be fixed in ExternalAppendOnlyMap.
+    }
   }
 
   /**
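Both spill paths above also switch from catch-and-rethrow cleanup to a success flag checked in finally, so a partially written spill is reverted and its file deleted on any failure path. A generic sketch of that pattern, with a hypothetical SpillWriter trait standing in for Spark's disk writer (commitAndClose and revertPartialWritesAndClose mirror the calls in the diff):

import java.io.File

// Hypothetical writer interface for illustration only.
trait SpillWriter {
  def write(record: Any): Unit
  def commitAndClose(): Unit
  def revertPartialWritesAndClose(): Unit
}

object SpillCleanupSketch {
  def spillTo(file: File, records: Iterator[Any], newWriter: () => SpillWriter): Unit = {
    var writer: SpillWriter = newWriter()
    var success = false
    try {
      records.foreach(writer.write)
      writer.commitAndClose()
      writer = null          // committed; nothing left to revert
      success = true
    } finally {
      if (!success) {
        // An exception is propagating: revert the partial write and delete the file,
        // rather than letting a half-written batch look like valid spill data.
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (file.exists()) {
          file.delete()
        }
      }
    }
  }
}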
