Commit 9a78e4b

Add @mridulm's fixes to ExternalAppendOnlyMap for batch sizes
All these changes are from @mridulm's work in apache#1609, but extracted here to fix this specific issue. This particular set of changes makes sure that we read exactly the right range of bytes from each spill file in ExternalAppendOnlyMap: some serializers can write bytes after the last object (e.g. the TC_RESET flag in Java serialization), and the previous code would read those trailing bytes as part of the next batch. There are also improvements to the cleanup logic to make sure spill files are closed.
1 parent 78f2af5 commit 9a78e4b
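
For context on the serializer behaviour described above, here is a minimal sketch (not part of this commit; the object and variable names are made up) showing how a plain java.io.ObjectOutputStream leaves a TC_RESET marker (0x79) after the last object once reset() is called. Those trailing bytes are counted in the batch's byte size but belong to no object, which is exactly what tripped up the old batch-reading code:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    object TrailingResetMarkerDemo {
      def main(args: Array[String]): Unit = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject("some spilled record")
        out.reset()   // ObjectOutputStream writes a TC_RESET marker into the stream
        out.close()

        // The final byte is the reset marker, not part of any serialized object. A reader
        // that keeps consuming the same stream after the last object of a batch would see
        // this byte at the start of what it thinks is the next batch.
        val last = bytes.toByteArray.last & 0xff
        println(f"last byte = 0x$last%02x")   // expected: 0x79 (TC_RESET)
      }
    }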


core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 77 additions & 20 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.collection
 
-import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
+import java.io._
 import java.util.Comparator
 
 import scala.collection.BufferedIterator
@@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 
@@ -199,13 +199,16 @@
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
       objectsWritten = 0
     }
 
+    var success = false
     try {
       val it = currentMap.destructiveSortedIterator(keyComparator)
       while (it.hasNext) {
@@ -215,16 +218,28 @@
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer.close()
           writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
       if (objectsWritten > 0) {
         flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
       }
+      success = true
     } finally {
-      // Partial failures cannot be tolerated; do not revert partial writes
-      writer.close()
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
+      }
     }
 
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
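
The spill loop above now follows a success-flag pattern: success is set only after every batch has been flushed (or the partial writer reverted), and the finally block reverts the writer and deletes the spill file when an exception is propagating. A standalone sketch of the same pattern, with made-up names rather than Spark's APIs:

    import java.io.{File, FileOutputStream, OutputStream}

    // Run `write`; if it fails part-way, don't leave a partial file behind.
    def writeOrDiscard(file: File)(write: OutputStream => Unit): Unit = {
      val out = new FileOutputStream(file)
      var success = false
      try {
        write(out)
        success = true
      } finally {
        out.close()
        if (!success && file.exists()) {
          file.delete()   // an exception is propagating; remove the partial output
        }
      }
    }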
@@ -390,26 +405,49 @@
    */
   private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
     extends Iterator[(K, C)] {
-    private val fileStream = new FileInputStream(file)
-    private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Size will be batchSize.length + 1
+    assert(file.length() == batchOffsets(batchOffsets.length - 1))
+
+    private var batchIndex = 0  // Which batch we're in
+    private var fileStream: FileInputStream = null
 
     // An intermediate stream that reads from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    private var batchStream = nextBatchStream()
-    private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-    private var deserializeStream = ser.deserializeStream(compressedStream)
+    private var deserializeStream = nextBatchStream()
     private var nextItem: (K, C) = null
     private var objectsRead = 0
 
     /**
      * Construct a stream that reads only from the next batch.
      */
-    private def nextBatchStream(): InputStream = {
-      if (batchSizes.length > 0) {
-        ByteStreams.limit(bufferedStream, batchSizes.remove(0))
+    private def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchIndex < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchIndex)
+        fileStream = new FileInputStream(file)
+        fileStream.getChannel.position(start)
+        batchIndex += 1
+
+        val end = batchOffsets(batchIndex)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+        ser.deserializeStream(compressedStream)
       } else {
         // No more batches left
-        bufferedStream
+        cleanup()
+        null
       }
     }
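
The heart of the new DiskMapIterator is the batchOffsets scan: the recorded per-batch byte counts become absolute offsets, so batch i is read as exactly the byte range [batchOffsets(i), batchOffsets(i + 1)), and whatever a serializer wrote after the last object of one batch can no longer bleed into the next. A small worked example with made-up batch sizes:

    import scala.collection.mutable.ArrayBuffer

    val batchSizes = ArrayBuffer(100L, 250L, 80L)       // bytes recorded per flushed batch
    val batchOffsets = batchSizes.scanLeft(0L)(_ + _)   // ArrayBuffer(0, 100, 350, 430)

    // Batch i spans [batchOffsets(i), batchOffsets(i + 1)); the final offset must equal the
    // spill file's length, which is what the assert in the constructor checks.
    val ranges = batchOffsets.sliding(2).map { case Seq(start, end) => (start, end) }.toList
    println(ranges)                                     // List((0,100), (100,350), (350,430))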

@@ -424,10 +462,8 @@
         val item = deserializeStream.readObject().asInstanceOf[(K, C)]
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
-          batchStream = nextBatchStream()
-          compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-          deserializeStream = ser.deserializeStream(compressedStream)
           objectsRead = 0
+          deserializeStream = nextBatchStream()
         }
         item
       } catch {
@@ -439,6 +475,9 @@
 
     override def hasNext: Boolean = {
       if (nextItem == null) {
+        if (deserializeStream == null) {
+          return false
+        }
         nextItem = readNextItem()
       }
       nextItem != null
@@ -455,7 +494,25 @@
 
     // TODO: Ensure this gets called even if the iterator isn't drained.
     private def cleanup() {
-      deserializeStream.close()
+      batchIndex = batchOffsets.length  // Prevent reading any other batch
+      val ds = deserializeStream
+      val fs = fileStream
+      deserializeStream = null
+      fileStream = null
+
+      if (ds != null) {
+        try {
+          ds.close()
+        } catch {
+          case e: IOException =>
+            // Make sure we at least close the file handle
+            if (fs != null) {
+              try { fs.close() } catch { case e2: IOException => }
+            }
+            throw e
+        }
+      }
+
       file.delete()
     }
   }
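
cleanup() also has to handle a close() that fails part-way: the deserialization stream wraps the compression stream, which wraps the raw FileInputStream, so if closing the outer stream throws, the code still closes the underlying file handle before rethrowing. The same idea in isolation, as a hedged sketch with illustrative names:

    import java.io.{Closeable, IOException}

    // Close an outer (wrapping) stream, but guarantee the underlying file handle is
    // released even if the outer close() throws.
    def closeWithFallback(outer: Closeable, underlying: Closeable): Unit = {
      try {
        outer.close()
      } catch {
        case e: IOException =>
          try { underlying.close() } catch { case _: IOException => }   // best effort
          throw e
      }
    }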
