
Commit 8faa521

[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes()
## What changes were proposed in this pull request?

`MemoryStore.putIteratorAsBytes()` may silently lose values when used with `KryoSerializer` because it does not properly close the serialization stream before attempting to deserialize the already-serialized values, which may cause values buffered in Kryo's internal buffers to not be read. This is the root cause behind a user-reported "wrong answer" bug in PySpark caching reported by bennoleslie on the Spark user mailing list in a thread titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK". Due to Spark 2.0's automatic use of KryoSerializer for "safe" types (such as byte arrays, primitives, etc.), this misuse of serializers manifested itself as silent data corruption rather than a StreamCorrupted error (which you might get from JavaSerializer).

The minimal fix, implemented here, is to close the serialization stream before attempting to deserialize the written values. In addition, this patch adds several additional assertions / precondition checks to prevent misuse of `PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`.

## How was this patch tested?

The original bug was masked by an invalid assert in the memory store test cases: the old assert compared two results record-by-record with `zip` but didn't first check that the lengths of the two collections were equal, causing missing records to go unnoticed. The updated test case reproduced this bug. In addition, I added a new `PartiallySerializedBlockSuite` to unit test that component.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix.
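A minimal sketch of the failure mode (illustrative only, not the MemoryStore code path; it assumes spark-core on the classpath, and the object name and value count are made up): values written through a Kryo-backed serialization stream may still sit in Kryo's internal buffer until the stream is closed, so reading the bytes back before `close()` can silently drop the tail of the data.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

object CloseBeforeReadBackSketch {
  def main(args: Array[String]): Unit = {
    val serializer = new KryoSerializer(new SparkConf()).newInstance()
    val bytes = new ByteArrayOutputStream()
    val serStream = serializer.serializeStream(bytes)
    (1 to 1000).foreach(i => serStream.writeObject(i)(ClassTag.Int))
    // Without this close(), whatever is still buffered inside Kryo never reaches `bytes`,
    // and the deserialized iterator below silently comes up short.
    serStream.close()
    val readBack = serializer
      .deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
      .asIterator
      .size
    assert(readBack == 1000, s"expected 1000 values but read back $readBack")
  }
}
```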
1 parent 86c2d39 commit 8faa521

8 files changed, +344 -44 lines changed

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ private[spark] object Task {
     dataOut.flush()
     val taskBytes = serializer.serialize(task)
     Utils.writeByteBuffer(taskBytes, out)
+    out.close()
     out.toByteBuffer
   }

core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala

Lines changed: 66 additions & 23 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
         "released too much unroll memory")
       Left(new PartiallyUnrolledIterator(
         this,
+        MemoryMode.ON_HEAP,
         unrollMemoryUsedByThisBlock,
         unrolled = arrayValues.toIterator,
         rest = Iterator.empty))

@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
       // We ran out of space while unrolling the values for this block
       logUnrollFailureMessage(blockId, vector.estimateSize())
       Left(new PartiallyUnrolledIterator(
-        this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+        this,
+        MemoryMode.ON_HEAP,
+        unrollMemoryUsedByThisBlock,
+        unrolled = vector.iterator,
+        rest = values))
     }
   }

@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
         redirectableStream,
         unrollMemoryUsedByThisBlock,
         memoryMode,
-        bbos.toChunkedByteBuffer,
+        bbos,
         values,
         classTag))
     }

@@ -655,20 +660,22 @@ private[spark] class MemoryStore(
  * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
  * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param unrolled an iterator for the partially-unrolled values.
  * @param rest the rest of the original iterator passed to
  *             [[MemoryStore.putIteratorAsValues()]].
  */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
+    memoryMode: MemoryMode,
     unrollMemory: Long,
     private[this] var unrolled: Iterator[T],
     rest: Iterator[T])
   extends Iterator[T] {

   private def releaseUnrollMemory(): Unit = {
-    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+    memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     // SPARK-17503: Garbage collects the unrolling memory before the life end of
     // PartiallyUnrolledIterator.
     unrolled = null

@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
 /**
  * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
  */
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
   private[this] var os: OutputStream = _
   def setOutputStream(s: OutputStream): Unit = { os = s }
   override def write(b: Int): Unit = os.write(b)

@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
  * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ *             [[redirectableOutputStream]] initially points to this output stream.
  * @param rest the rest of the original iterator passed to
  *             [[MemoryStore.putIteratorAsValues()]].
  * @param classTag the [[ClassTag]] for the block.

@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
     memoryStore: MemoryStore,
     serializerManager: SerializerManager,
     blockId: BlockId,
-    serializationStream: SerializationStream,
-    redirectableOutputStream: RedirectableOutputStream,
-    unrollMemory: Long,
+    private val serializationStream: SerializationStream,
+    private val redirectableOutputStream: RedirectableOutputStream,
+    val unrollMemory: Long,
     memoryMode: MemoryMode,
-    unrolled: ChunkedByteBuffer,
+    bbos: ChunkedByteBufferOutputStream,
     rest: Iterator[T],
     classTag: ClassTag[T]) {

+  private lazy val unrolledBuffer: ChunkedByteBuffer = {
+    bbos.close()
+    bbos.toChunkedByteBuffer
+  }
+
   // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
   // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
   // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.

@@ -751,23 +764,42 @@ private[storage] class PartiallySerializedBlock[T](
     taskContext.addTaskCompletionListener { _ =>
       // When a task completes, its unroll memory will automatically be freed. Thus we do not call
       // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
-      unrolled.dispose()
+      unrolledBuffer.dispose()
+    }
+  }
+
+  // Exposed for testing
+  private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+  private[this] var discarded = false
+  private[this] var consumed = false
+
+  private def verifyNotConsumedAndNotDiscarded(): Unit = {
+    if (consumed) {
+      throw new IllegalStateException(
+        "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+    }
+    if (discarded) {
+      throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
     }
   }

   /**
    * Called to dispose of this block and free its memory.
    */
   def discard(): Unit = {
-    try {
-      // We want to close the output stream in order to free any resources associated with the
-      // serializer itself (such as Kryo's internal buffers). close() might cause data to be
-      // written, so redirect the output stream to discard that data.
-      redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
-      serializationStream.close()
-    } finally {
-      unrolled.dispose()
-      memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+    if (!discarded) {
+      try {
+        // We want to close the output stream in order to free any resources associated with the
+        // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+        // written, so redirect the output stream to discard that data.
+        redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+        serializationStream.close()
+      } finally {
+        discarded = true
+        unrolledBuffer.dispose()
+        memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+      }
     }
   }

@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
    * and then serializing the values from the original input iterator.
    */
   def finishWritingToStream(os: OutputStream): Unit = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
-    ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+    ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
     memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     redirectableOutputStream.setOutputStream(os)
     while (rest.hasNext) {

@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
    * `close()` on it to free its resources.
    */
   def valuesIterator: PartiallyUnrolledIterator[T] = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
+    // Close the serialization stream so that the serializer's internal buffers are freed and any
+    // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+    serializationStream.close()
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
     val unrolledIter = serializerManager.dataDeserializeStream(
-      blockId, unrolled.toInputStream(dispose = true))(classTag)
+      blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+    // The unroll memory will be freed once `unrolledIter` is fully consumed in
+    // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+    // extra unroll memory will automatically be freed by a `finally` block in `Task`.
     new PartiallyUnrolledIterator(
       memoryStore,
+      memoryMode,
       unrollMemory,
-      unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+      unrolled = unrolledIter,
       rest = rest)
   }
 }

core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala

Lines changed: 26 additions & 1 deletion
@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp

   def getCount(): Int = count

+  private[this] var closed: Boolean = false
+
+  override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b)
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b, off, len)
+  }
+
+  override def reset(): Unit = {
+    require(!closed, "cannot reset a closed ByteBufferOutputStream")
+    super.reset()
+  }
+
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   def toByteBuffer: ByteBuffer = {
-    return ByteBuffer.wrap(buf, 0, count)
+    require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+    ByteBuffer.wrap(buf, 0, count)
   }
 }
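As a usage sketch of the tightened `ByteBufferOutputStream` contract (illustrative only; the class is `private[spark]`, so this compiles only from a package under `org.apache.spark`, and the capacity and payload are arbitrary): write, then close, then extract the buffer, mirroring the `out.close()` added to Task.scala above.

```scala
package org.apache.spark.demo // hypothetical package; needed because the class is private[spark]

import java.nio.ByteBuffer

import org.apache.spark.util.ByteBufferOutputStream

object ByteBufferOutputStreamSketch {
  def main(args: Array[String]): Unit = {
    val out = new ByteBufferOutputStream(64)
    out.write(Array[Byte](1, 2, 3), 0, 3)
    out.close() // toByteBuffer() now requires this; write()/reset() after close() fail fast
    val buf: ByteBuffer = out.toByteBuffer
    assert(buf.remaining() == 3)
  }
}
```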

core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala

Lines changed: 11 additions & 1 deletion
@@ -49,17 +49,27 @@ private[spark] class ChunkedByteBufferOutputStream(
    */
   private[this] var position = chunkSize
   private[this] var _size = 0
+  private[this] var closed: Boolean = false

   def size: Long = _size

+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     allocateNewChunkIfNeeded()
     chunks(lastChunkIndex).put(b.toByte)
     position += 1
     _size += 1
   }

   override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     var written = 0
     while (written < len) {
       allocateNewChunkIfNeeded()

@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(

   @inline
   private def allocateNewChunkIfNeeded(): Unit = {
-    require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
     if (position == chunkSize) {
       chunks += allocator(chunkSize)
       lastChunkIndex += 1

@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
   }

   def toChunkedByteBuffer: ChunkedByteBuffer = {
+    require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
     require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
     toChunkedByteBufferWasCalled = true
     if (lastChunkIndex == -1) {
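The same write/close/extract ordering now applies to `ChunkedByteBufferOutputStream`, the stream that `MemoryStore.putIteratorAsBytes()` unrolls into; a hedged sketch (again only compilable under `org.apache.spark`, with an arbitrary chunk size and a plain heap allocator):

```scala
package org.apache.spark.demo // hypothetical package; the class is private[spark]

import java.nio.ByteBuffer

import org.apache.spark.util.io.ChunkedByteBufferOutputStream

object ChunkedByteBufferOutputStreamSketch {
  def main(args: Array[String]): Unit = {
    val bbos = new ChunkedByteBufferOutputStream(chunkSize = 32, allocator = ByteBuffer.allocate _)
    bbos.write(Array.fill[Byte](100)(1), 0, 100) // spans several 32-byte chunks
    bbos.close()                                 // now required before toChunkedByteBuffer()
    val chunked = bbos.toChunkedByteBuffer       // and toChunkedByteBuffer() stays one-shot
    assert(chunked.size == 100)
  }
}
```

This is the same ordering that `PartiallySerializedBlock.unrolledBuffer` above follows: close `bbos` first, then convert it.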

core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala

Lines changed: 16 additions & 18 deletions
@@ -79,6 +79,13 @@ class MemoryStoreSuite
     (memoryStore, blockInfoManager)
   }

+  private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+    assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+    expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+      assert(e === a, s"$hint did not return original values!")
+    }
+  }
+
   test("reserve/release unroll memory") {
     val (memoryStore, _) = makeMemoryStore(12000)
     assert(memoryStore.currentUnrollMemory === 0)

@@ -137,9 +144,7 @@ class MemoryStoreSuite
     var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
     assert(putResult.isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")

@@ -152,9 +157,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     assert(memoryStore.contains("someBlock2"))
     assert(!memoryStore.contains("someBlock1"))
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")

@@ -167,9 +170,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
     assert(!memoryStore.contains("someBlock2"))
     assert(putResult.isLeft)
-    bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
-      assert(e === a, "putIterator() did not return original values!")
-    }
+    assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
     // The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }

@@ -316,9 +317,8 @@ class MemoryStoreSuite
     assert(res.isLeft)
     assert(memoryStore.currentUnrollMemoryForThisTask > 0)
     val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
-    valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
-      assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
-    }
+    assertSameContents(
+      bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
     // The unroll memory was freed once the iterator was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }

@@ -340,12 +340,10 @@ class MemoryStoreSuite
     res.left.get.finishWritingToStream(bos)
     // The unroll memory was freed once the block was fully written.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    val deserializationStream = serializerManager.dataDeserializeStream[Any](
-      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
-    deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
-      assert(e === a,
-        "PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
-    }
+    val deserializedValues = serializerManager.dataDeserializeStream[Any](
+      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+    assertSameContents(
+      bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
   }

   test("multiple unrolls by the same thread") {
