Skip to content

[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes() #15043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ private[spark] object Task {
dataOut.flush()
val taskBytes = serializer.serialize(task)
Utils.writeByteBuffer(taskBytes, out)
out.close()
out.toByteBuffer
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

Expand Down Expand Up @@ -277,6 +277,7 @@ private[spark] class MemoryStore(
"released too much unroll memory")
Left(new PartiallyUnrolledIterator(
this,
MemoryMode.ON_HEAP,
unrollMemoryUsedByThisBlock,
unrolled = arrayValues.toIterator,
rest = Iterator.empty))
Expand All @@ -285,7 +286,11 @@ private[spark] class MemoryStore(
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, vector.estimateSize())
Left(new PartiallyUnrolledIterator(
this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
this,
MemoryMode.ON_HEAP,
unrollMemoryUsedByThisBlock,
unrolled = vector.iterator,
rest = values))
}
}

Expand Down Expand Up @@ -394,7 +399,7 @@ private[spark] class MemoryStore(
redirectableStream,
unrollMemoryUsedByThisBlock,
memoryMode,
bbos.toChunkedByteBuffer,
bbos,
values,
classTag))
}
Expand Down Expand Up @@ -655,20 +660,22 @@ private[spark] class MemoryStore(
* The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
*
* @param memoryStore the memoryStore, used for freeing memory.
* @param memoryMode the memory mode (on- or off-heap).
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param unrolled an iterator for the partially-unrolled values.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
*/
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
memoryMode: MemoryMode,
unrollMemory: Long,
private[this] var unrolled: Iterator[T],
rest: Iterator[T])
extends Iterator[T] {

private def releaseUnrollMemory(): Unit = {
memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
// SPARK-17503: Drop the reference to the unrolled values so that they can be garbage
// collected before this PartiallyUnrolledIterator itself becomes unreachable.
unrolled = null
Expand Down Expand Up @@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
/**
* A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
*/
private class RedirectableOutputStream extends OutputStream {
private[storage] class RedirectableOutputStream extends OutputStream {
private[this] var os: OutputStream = _
def setOutputStream(s: OutputStream): Unit = { os = s }
override def write(b: Int): Unit = os.write(b)
Expand All @@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
* @param memoryMode whether the unroll memory is on- or off-heap
* @param unrolled a byte buffer containing the partially-serialized values.
* @param bbos byte buffer output stream containing the partially-serialized values.
* [[redirectableOutputStream]] initially points to this output stream.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
* @param classTag the [[ClassTag]] for the block.
Expand All @@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
memoryStore: MemoryStore,
serializerManager: SerializerManager,
blockId: BlockId,
serializationStream: SerializationStream,
redirectableOutputStream: RedirectableOutputStream,
unrollMemory: Long,
private val serializationStream: SerializationStream,
private val redirectableOutputStream: RedirectableOutputStream,
val unrollMemory: Long,
memoryMode: MemoryMode,
unrolled: ChunkedByteBuffer,
bbos: ChunkedByteBufferOutputStream,
rest: Iterator[T],
classTag: ClassTag[T]) {

// The bytes serialized into `bbos` so far, exposed as a ChunkedByteBuffer. `bbos` is closed
// before the conversion, and because this is a lazy val the close-and-convert step runs at
// most once, on first access.
private lazy val unrolledBuffer: ChunkedByteBuffer = {
bbos.close()
bbos.toChunkedByteBuffer
}

// If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
// this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
// completion listener here in order to ensure that `unrolledBuffer.dispose()` is called at
// least once.
Expand All @@ -751,23 +764,42 @@ private[storage] class PartiallySerializedBlock[T](
taskContext.addTaskCompletionListener { _ =>
// When a task completes, its unroll memory will automatically be freed. Thus we do not call
// releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
unrolled.dispose()
unrolledBuffer.dispose()
}
}

// Exposed for testing
private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer

private[this] var discarded = false
private[this] var consumed = false

/**
 * Guard for the one-shot entry points of this block: fails if the block has already been
 * consumed or has been discarded. The "consumed" check takes precedence over "discarded".
 *
 * @throws IllegalStateException if either the `consumed` or the `discarded` flag is set.
 */
private def verifyNotConsumedAndNotDiscarded(): Unit = {
  (consumed, discarded) match {
    case (true, _) =>
      throw new IllegalStateException(
        "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
    case (_, true) =>
      throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
    case _ =>
      // Neither consumed nor discarded: nothing to report.
      ()
  }
}

/**
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
try {
// We want to close the output stream in order to free any resources associated with the
// serializer itself (such as Kryo's internal buffers). close() might cause data to be
// written, so redirect the output stream to discard that data.
redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
serializationStream.close()
} finally {
unrolled.dispose()
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
if (!discarded) {
try {
// We want to close the output stream in order to free any resources associated with the
// serializer itself (such as Kryo's internal buffers). close() might cause data to be
// written, so redirect the output stream to discard that data.
redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
serializationStream.close()
} finally {
discarded = true
unrolledBuffer.dispose()
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
}
}
}

Expand All @@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
verifyNotConsumedAndNotDiscarded()
consumed = true
// `unrolledBuffer`'s underlying buffers will be freed once this input stream is fully read:
ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
Expand All @@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
verifyNotConsumedAndNotDiscarded()
consumed = true
// Close the serialization stream so that the serializer's internal buffers are freed and any
// "end-of-stream" markers can be written out so that `unrolledBuffer` is a valid serialized
// stream.
serializationStream.close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like 'unrolled' may basically be invalid until serializationStream.close() is called.

But it looks like valuesIterator is not the only place where 'unrolled' is used.

// `unrolledBuffer`'s underlying buffers will be freed once this input stream is fully read:
val unrolledIter = serializerManager.dataDeserializeStream(
blockId, unrolled.toInputStream(dispose = true))(classTag)
blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
// The unroll memory will be freed once `unrolledIter` is fully consumed in
// PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
// extra unroll memory will automatically be freed by a `finally` block in `Task`.
new PartiallyUnrolledIterator(
memoryStore,
memoryMode,
unrollMemory,
unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
unrolled = unrolledIter,
rest = rest)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp

def getCount(): Int = count

private[this] var closed: Boolean = false

/**
 * Writes a single byte by delegating to the superclass.
 * The `require` precondition rejects the call once this stream has been closed.
 */
override def write(b: Int): Unit = {
require(!closed, "cannot write to a closed ByteBufferOutputStream")
super.write(b)
}

/**
 * Writes `len` bytes of `b`, starting at offset `off`, by delegating to the superclass.
 * The `require` precondition rejects the call once this stream has been closed.
 */
override def write(b: Array[Byte], off: Int, len: Int): Unit = {
require(!closed, "cannot write to a closed ByteBufferOutputStream")
super.write(b, off, len)
}

/**
 * Resets the stream by delegating to the superclass.
 * The `require` precondition rejects the call once this stream has been closed.
 */
override def reset(): Unit = {
require(!closed, "cannot reset a closed ByteBufferOutputStream")
super.reset()
}

/**
 * Marks this stream as closed. Only the first call invokes `super.close()` and sets the flag;
 * later calls do nothing.
 */
override def close(): Unit = {
if (!closed) {
super.close()
closed = true
}
}

def toByteBuffer: ByteBuffer = {
return ByteBuffer.wrap(buf, 0, count)
require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
ByteBuffer.wrap(buf, 0, count)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,27 @@ private[spark] class ChunkedByteBufferOutputStream(
*/
private[this] var position = chunkSize
private[this] var _size = 0
private[this] var closed: Boolean = false

def size: Long = _size

/**
 * Marks this stream as closed. Only the first call invokes `super.close()` and sets the flag;
 * later calls do nothing.
 */
override def close(): Unit = {
if (!closed) {
super.close()
closed = true
}
}

/**
 * Appends a single byte to the last chunk, after giving `allocateNewChunkIfNeeded()` the
 * chance to add a fresh chunk. The `require` precondition rejects the call once this stream
 * has been closed.
 */
override def write(b: Int): Unit = {
require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
allocateNewChunkIfNeeded()
// Only the low-order byte of `b` is stored.
chunks(lastChunkIndex).put(b.toByte)
// Advance both the position counter and the total size by one byte.
position += 1
_size += 1
}

override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
var written = 0
while (written < len) {
allocateNewChunkIfNeeded()
Expand All @@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(

@inline
private def allocateNewChunkIfNeeded(): Unit = {
require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
if (position == chunkSize) {
chunks += allocator(chunkSize)
lastChunkIndex += 1
Expand All @@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
}

def toChunkedByteBuffer: ChunkedByteBuffer = {
require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
toChunkedByteBufferWasCalled = true
if (lastChunkIndex == -1) {
Expand Down
34 changes: 16 additions & 18 deletions core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ class MemoryStoreSuite
(memoryStore, blockInfoManager)
}

/**
 * Asserts that `actual` holds the same values as `expected`, in the same order.
 *
 * The lengths are compared first, so a truncated or over-long result fails with its own
 * message instead of being hidden by the pairwise comparison.
 *
 * @param expected the values that should have been returned.
 * @param actual the values that were returned.
 * @param hint name of the operation under test, interpolated into the failure messages.
 */
private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
  assert(actual.length === expected.length, s"wrong number of values returned in $hint")
  for ((e, a) <- expected.iterator.zip(actual.iterator)) {
    assert(e === a, s"$hint did not return original values!")
  }
}

test("reserve/release unroll memory") {
val (memoryStore, _) = makeMemoryStore(12000)
assert(memoryStore.currentUnrollMemory === 0)
Expand Down Expand Up @@ -137,9 +144,7 @@ class MemoryStoreSuite
var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
assert(putResult.isRight)
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
assert(e === a, "getValues() did not return original values!")
}
assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
Expand All @@ -152,9 +157,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
assert(memoryStore.contains("someBlock2"))
assert(!memoryStore.contains("someBlock1"))
smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
assert(e === a, "getValues() did not return original values!")
}
assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
blockInfoManager.lockForWriting("unroll")
assert(memoryStore.remove("unroll"))
blockInfoManager.removeBlock("unroll")
Expand All @@ -167,9 +170,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
assert(!memoryStore.contains("someBlock2"))
assert(putResult.isLeft)
bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
assert(e === a, "putIterator() did not return original values!")
}
assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
// The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
Expand Down Expand Up @@ -316,9 +317,8 @@ class MemoryStoreSuite
assert(res.isLeft)
assert(memoryStore.currentUnrollMemoryForThisTask > 0)
val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
}
assertSameContents(
bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
// The unroll memory was freed once the iterator was fully traversed.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
Expand All @@ -340,12 +340,10 @@ class MemoryStoreSuite
res.left.get.finishWritingToStream(bos)
// The unroll memory was freed once the block was fully written.
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
val deserializationStream = serializerManager.dataDeserializeStream[Any](
"b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
assert(e === a,
"PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
}
val deserializedValues = serializerManager.dataDeserializeStream[Any](
"b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
assertSameContents(
bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
}

test("multiple unrolls by the same thread") {
Expand Down
Loading