 
 package org.apache.spark.util.collection
 
-import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
+import java.io._
 import java.util.Comparator
 
 import scala.collection.BufferedIterator
@@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 
@@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
       objectsWritten = 0
     }
 
+    var success = false
     try {
       val it = currentMap.destructiveSortedIterator(keyComparator)
       while (it.hasNext) {
@@ -215,16 +218,28 @@ class ExternalAppendOnlyMap[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer.close()
           writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
       if (objectsWritten > 0) {
         flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
       }
+      success = true
     } finally {
-      // Partial failures cannot be tolerated; do not revert partial writes
-      writer.close()
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
+      }
     }
 
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
@@ -390,26 +405,49 @@ class ExternalAppendOnlyMap[K, V, C](
    */
   private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
     extends Iterator[(K, C)] {
-    private val fileStream = new FileInputStream(file)
-    private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Size will be batchSize.length + 1
+    assert(file.length() == batchOffsets(batchOffsets.length - 1))
+
+    private var batchIndex = 0  // Which batch we're in
+    private var fileStream: FileInputStream = null
 
     // An intermediate stream that reads from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    private var batchStream = nextBatchStream()
-    private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-    private var deserializeStream = ser.deserializeStream(compressedStream)
+    private var deserializeStream = nextBatchStream()
     private var nextItem: (K, C) = null
     private var objectsRead = 0
 
     /**
      * Construct a stream that reads only from the next batch.
      */
-    private def nextBatchStream(): InputStream = {
-      if (batchSizes.length > 0) {
-        ByteStreams.limit(bufferedStream, batchSizes.remove(0))
+    private def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchIndex < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchIndex)
+        fileStream = new FileInputStream(file)
+        fileStream.getChannel.position(start)
+        batchIndex += 1
+
+        val end = batchOffsets(batchIndex)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+        ser.deserializeStream(compressedStream)
       } else {
         // No more batches left
-        bufferedStream
+        cleanup()
+        null
       }
     }
 
@@ -424,10 +462,8 @@ class ExternalAppendOnlyMap[K, V, C](
         val item = deserializeStream.readObject().asInstanceOf[(K, C)]
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
-          batchStream = nextBatchStream()
-          compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-          deserializeStream = ser.deserializeStream(compressedStream)
           objectsRead = 0
+          deserializeStream = nextBatchStream()
         }
         item
       } catch {
@@ -439,6 +475,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
     override def hasNext: Boolean = {
       if (nextItem == null) {
+        if (deserializeStream == null) {
+          return false
+        }
         nextItem = readNextItem()
       }
       nextItem != null
@@ -455,7 +494,25 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // TODO: Ensure this gets called even if the iterator isn't drained.
     private def cleanup() {
-      deserializeStream.close()
+      batchIndex = batchOffsets.length  // Prevent reading any other batch
+      val ds = deserializeStream
+      val fs = fileStream
+      deserializeStream = null
+      fileStream = null
+
+      if (ds != null) {
+        try {
+          ds.close()
+        } catch {
+          case e: IOException =>
+            // Make sure we at least close the file handle
+            if (fs != null) {
+              try { fs.close() } catch { case e2: IOException => }
+            }
+            throw e
+        }
+      }
+
       file.delete()
     }
   }
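
Aside (not part of the patch): a minimal sketch of the success-flag cleanup pattern that the new spill() code above relies on. SpillWriter, openWriter, and the entries iterator here are made-up stand-ins for Spark's disk writer and spill loop; the point is that only the failure path reverts partial writes and deletes the spill file, while the happy path nulls out the writer before committing so the finally block cannot touch a writer that already committed.

import java.io.File

// Stand-in for the real disk writer; only the calls the pattern needs.
trait SpillWriter {
  def write(kv: Any): Unit
  def commitAndClose(): Unit
  def revertPartialWritesAndClose(): Unit
}

object SpillCleanupSketch {
  def spill(entries: Iterator[Any], file: File, openWriter: () => SpillWriter): Unit = {
    var writer: SpillWriter = openWriter()
    var success = false
    try {
      entries.foreach(writer.write)
      // Clear the reference before committing so the finally block cannot
      // double-close a writer whose commit already succeeded or failed midway.
      val w = writer
      writer = null
      w.commitAndClose()
      success = true
    } finally {
      if (!success) {
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        // Remove the partially written spill file instead of leaving a truncated
        // batch on disk for the reader to trip over later.
        if (file.exists()) {
          file.delete()
        }
      }
    }
  }
}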
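
Also for reference, a small standalone illustration (sizes made up) of the batchOffsets bookkeeping that DiskMapIterator now uses: scanLeft over the per-batch byte counts recorded by flush() yields cumulative start offsets, with one trailing element equal to the total spill-file length.

object BatchOffsetsSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical per-batch byte counts, one per flush()
    val batchSizes = Seq(120L, 80L, 200L)

    // scanLeft prepends a 0, so batchOffsets.length == batchSizes.length + 1
    // and the last element is the total file length (400 here).
    val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // 0, 120, 200, 400

    // Batch i occupies bytes [batchOffsets(i), batchOffsets(i + 1)), so the iterator can seek
    // a fresh FileInputStream to batchOffsets(i) and limit reads to that batch's length.
    for (i <- batchSizes.indices) {
      println(s"batch $i: start=${batchOffsets(i)}, length=${batchOffsets(i + 1) - batchOffsets(i)}")
    }
  }
}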