@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
           "released too much unroll memory")
         Left(new PartiallyUnrolledIterator(
           this,
+          MemoryMode.ON_HEAP,
           unrollMemoryUsedByThisBlock,
           unrolled = arrayValues.toIterator,
           rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
       // We ran out of space while unrolling the values for this block
       logUnrollFailureMessage(blockId, vector.estimateSize())
       Left(new PartiallyUnrolledIterator(
-        this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+        this,
+        MemoryMode.ON_HEAP,
+        unrollMemoryUsedByThisBlock,
+        unrolled = vector.iterator,
+        rest = values))
     }
   }
 
@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
         redirectableStream,
         unrollMemoryUsedByThisBlock,
         memoryMode,
-        bbos.toChunkedByteBuffer,
+        bbos,
         values,
         classTag))
     }
@@ -655,20 +660,22 @@ private[spark] class MemoryStore(
  * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
  * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param unrolled an iterator for the partially-unrolled values.
  * @param rest the rest of the original iterator passed to
  *             [[MemoryStore.putIteratorAsValues()]].
  */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
+    memoryMode: MemoryMode,
     unrollMemory: Long,
     private[this] var unrolled: Iterator[T],
     rest: Iterator[T])
   extends Iterator[T] {
 
   private def releaseUnrollMemory(): Unit = {
-    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+    memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     // SPARK-17503: Garbage collects the unrolling memory before the life end of
     // PartiallyUnrolledIterator.
     unrolled = null
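
For orientation, this change threads the block's `MemoryMode` into `PartiallyUnrolledIterator` so that unroll memory is returned to the pool it actually came from, rather than always being credited back on-heap. The iterator's contract is to release that memory exactly once, as soon as the unrolled portion is exhausted, and to drop the reference so the unrolled values can be garbage collected early (SPARK-17503). A minimal standalone sketch of that pattern, with illustrative names that are not part of the patch:

```scala
// Sketch only: mirrors the release-once behaviour of PartiallyUnrolledIterator without
// depending on MemoryStore. `release` stands in for
// memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory).
class ReleaseOnceIterator[T](
    private var buffered: Iterator[T],
    rest: Iterator[T],
    release: () => Unit) extends Iterator[T] {

  private def releaseBuffered(): Unit = {
    release()        // give the unroll memory back exactly once
    buffered = null  // drop the reference so the buffered values can be GC'd early
  }

  override def hasNext: Boolean = {
    if (buffered != null && !buffered.hasNext) releaseBuffered()
    if (buffered != null) true else rest.hasNext
  }

  override def next(): T = {
    if (buffered != null && buffered.hasNext) {
      buffered.next()
    } else {
      if (buffered != null) releaseBuffered()
      rest.next()
    }
  }
}
```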
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
 /**
  * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
  */
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
   private[this] var os: OutputStream = _
   def setOutputStream(s: OutputStream): Unit = { os = s }
   override def write(b: Int): Unit = os.write(b)
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
  * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ *             [[redirectableOutputStream]] initially points to this output stream.
  * @param rest the rest of the original iterator passed to
  *             [[MemoryStore.putIteratorAsValues()]].
  * @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
     memoryStore: MemoryStore,
     serializerManager: SerializerManager,
     blockId: BlockId,
-    serializationStream: SerializationStream,
-    redirectableOutputStream: RedirectableOutputStream,
-    unrollMemory: Long,
+    private val serializationStream: SerializationStream,
+    private val redirectableOutputStream: RedirectableOutputStream,
+    val unrollMemory: Long,
     memoryMode: MemoryMode,
-    unrolled: ChunkedByteBuffer,
+    bbos: ChunkedByteBufferOutputStream,
     rest: Iterator[T],
     classTag: ClassTag[T]) {
 
+  private lazy val unrolledBuffer: ChunkedByteBuffer = {
+    bbos.close()
+    bbos.toChunkedByteBuffer
+  }
+
   // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
   // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
   // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
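
The constructor now takes the `ChunkedByteBufferOutputStream` itself rather than an eagerly built `ChunkedByteBuffer`; `unrolledBuffer` materializes the buffer lazily and closes the stream first, so the conversion always happens after everything written to the stream has been sealed in. A toy illustration of the same close-then-convert idiom using only the standard library (the `LazySnapshot` name is illustrative, not from the patch):

```scala
import java.io.ByteArrayOutputStream

// Toy stand-in for `unrolledBuffer` above: the byte snapshot is built at most once,
// only when first needed, and the stream is sealed before it is read back.
class LazySnapshot(out: ByteArrayOutputStream) {
  lazy val bytes: Array[Byte] = {
    out.close()      // stop further writes and flush anything still buffered
    out.toByteArray  // then take the snapshot exactly once (memoized by `lazy val`)
  }
}
```

The point is the ordering: converting before the stream is closed could observe an incomplete buffer, so the conversion is tied to a `close()` that is guaranteed to run first.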
@@ -751,23 +764,42 @@ private[storage] class PartiallySerializedBlock[T](
     taskContext.addTaskCompletionListener { _ =>
       // When a task completes, its unroll memory will automatically be freed. Thus we do not call
       // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
-      unrolled.dispose()
+      unrolledBuffer.dispose()
+    }
+  }
+
+  // Exposed for testing
+  private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+  private[this] var discarded = false
+  private[this] var consumed = false
+
+  private def verifyNotConsumedAndNotDiscarded(): Unit = {
+    if (consumed) {
+      throw new IllegalStateException(
+        "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+    }
+    if (discarded) {
+      throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
     }
   }
 
   /**
    * Called to dispose of this block and free its memory.
    */
   def discard(): Unit = {
-    try {
-      // We want to close the output stream in order to free any resources associated with the
-      // serializer itself (such as Kryo's internal buffers). close() might cause data to be
-      // written, so redirect the output stream to discard that data.
-      redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
-      serializationStream.close()
-    } finally {
-      unrolled.dispose()
-      memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+    if (!discarded) {
+      try {
+        // We want to close the output stream in order to free any resources associated with the
+        // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+        // written, so redirect the output stream to discard that data.
+        redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+        serializationStream.close()
+      } finally {
+        discarded = true
+        unrolledBuffer.dispose()
+        memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+      }
     }
   }
 
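
Taken together, the new `consumed` and `discarded` flags turn `PartiallySerializedBlock` into a single-use object: exactly one of `finishWritingToStream()` or `valuesIterator()` may be called, at most once, and a second `discard()` is now a no-op instead of disposing the buffer and freeing unroll memory twice. A hedged usage sketch, written as it might appear in package-private test code (the helper name is illustrative and not part of the patch):

```scala
// Illustrative only: assumes we are inside org.apache.spark.storage (the class is
// private[storage]) and that `block` came from a failed putIteratorAsBytes() call.
def consumeOnce[T](block: PartiallySerializedBlock[T], keepValues: Boolean): Unit = {
  if (keepValues) {
    val iter = block.valuesIterator   // marks the block as consumed
    while (iter.hasNext) iter.next()  // drain; memory is freed once the unrolled part is exhausted
    // block.finishWritingToStream(...)  // would now throw IllegalStateException
  } else {
    block.discard()                   // releases the buffer and the unroll memory
    block.discard()                   // second call is a no-op thanks to `discarded`
  }
}
```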
@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
    * and then serializing the values from the original input iterator.
    */
   def finishWritingToStream(os: OutputStream): Unit = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
-    ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+    ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
     memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     redirectableOutputStream.setOutputStream(os)
     while (rest.hasNext) {
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
    * `close()` on it to free its resources.
    */
   def valuesIterator: PartiallyUnrolledIterator[T] = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
+    // Close the serialization stream so that the serializer's internal buffers are freed and any
+    // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+    serializationStream.close()
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
     val unrolledIter = serializerManager.dataDeserializeStream(
-      blockId, unrolled.toInputStream(dispose = true))(classTag)
+      blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+    // The unroll memory will be freed once `unrolledIter` is fully consumed in
+    // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+    // extra unroll memory will automatically be freed by a `finally` block in `Task`.
     new PartiallyUnrolledIterator(
       memoryStore,
+      memoryMode,
       unrollMemory,
-      unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+      unrolled = unrolledIter,
       rest = rest)
   }
 }
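
One subtle point in `valuesIterator` above is that the `SerializationStream` must be closed before the bytes are deserialized: until `close()`, the serializer may still be holding buffered data or may not yet have written its end-of-stream marker, so the byte buffer would not be a complete serialized stream. The same effect is easy to reproduce with plain Java serialization; this toy program is not part of the patch and does not use Spark's `SerializerManager`:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Toy demonstration of "close before you read back": with the close() removed, the
// ObjectInputStream below would typically fail with an EOFException because part of the
// serialized record is still sitting in ObjectOutputStream's internal buffer.
object CloseBeforeRead {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject("partially serialized block")
    out.close()  // flush buffered data and finish the stream

    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    println(in.readObject())  // prints: partially serialized block
  }
}
```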