[SPARK-17110] Fix StreamCorruptedException in BlockManager.getRemoteValues() #14952

Status: Closed · wants to merge 4 commits
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
     assertValid()
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDPartition].blockId
-    blockManager.get(blockId) match {
+    blockManager.get[T](blockId) match {
       case Some(block) => block.data.asInstanceOf[Iterator[T]]
       case None =>
         throw new Exception("Could not compute split, block " + blockId + " not found")
@@ -180,11 +180,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * Deserializes an InputStream into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserializeStream[T: ClassTag](
+  def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream): Iterator[T] = {
+      inputStream: InputStream)
+      (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    getSerializer(implicitly[ClassTag[T]])
+    getSerializer(classTag)
       .newInstance()
       .deserializeStream(wrapStream(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
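With the ClassTag moved into an explicit second parameter list, callers can no longer rely on silent inference. A minimal call-site sketch (the names `sm`, `blockId`, and `in` are stand-ins, not values from this patch; note that `SerializerManager` is `private[spark]`, so real code like this must live under `org.apache.spark`):

```scala
import java.io.InputStream
import scala.reflect.ClassTag

import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.BlockId

// Hypothetical caller: the element type is now a visible, deliberate choice.
// Forgetting it is a compile error rather than a silent ClassTag[Any].
def readInts(sm: SerializerManager, blockId: BlockId, in: InputStream): Iterator[Int] =
  sm.dataDeserializeStream(blockId, in)(ClassTag.Int)
```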
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -520,10 +520,11 @@ private[spark] class BlockManager(
    *
    * This does not acquire a lock on this block in this JVM.
    */
-  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+  private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+    val ct = implicitly[ClassTag[T]]
     getRemoteBytes(blockId).map { data =>
       val values =
-        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
Inline review thread on this line:

Contributor: Is it possible for dataDeserializeStream to require a classtag to be explicitly passed?

Contributor: I'm not saying this should definitely be done one way or the other, but I'm curious why you have a preference for the extra code and more verbose API that come with making the classTag an explicit parameter.

Contributor: It seems like it is easy to accidentally forget to pass a correct classtag, since this has happened twice already.

Contributor: How do you forget to pass a correct ClassTag when the compiler is enforcing its presence via the context bound?

Contributor (Author): In this case, the problem is that the type parameter was inferred as Any.

       new BlockResult(values, DataReadMethod.Network, data.size)
     }
   }
@@ -602,13 +603,13 @@ private[spark] class BlockManager(
    * any locks if the block was fetched from a remote block manager. The read lock will
    * automatically be freed once the result's `data` iterator is fully consumed.
    */
-  def get(blockId: BlockId): Option[BlockResult] = {
+  def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
     val local = getLocalValues(blockId)
     if (local.isDefined) {
       logInfo(s"Found block $blockId locally")
       return local
     }
-    val remote = getRemoteValues(blockId)
+    val remote = getRemoteValues[T](blockId)
     if (remote.isDefined) {
       logInfo(s"Found block $blockId remotely")
       return remote
@@ -660,7 +661,7 @@
       makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
     // Attempt to read the block from local or remote storage. If it's present, then we don't need
     // to go through the local-get-or-put path.
-    get(blockId) match {
+    get[T](blockId)(classTag) match {
       case Some(block) =>
         return Left(block)
       case _ =>
@@ -1204,8 +1205,8 @@ private[spark] class BlockManager(
   /**
    * Read a block consisting of a single object.
    */
-  def getSingle(blockId: BlockId): Option[Any] = {
-    get(blockId).map(_.data.next())
+  def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+    get[T](blockId).map(_.data.next().asInstanceOf[T])
   }

   /**
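Why does a tag of Any corrupt the stream instead of merely losing type precision? A simplified, self-contained model of serializer selection keyed on the ClassTag (an assumption for illustration, not Spark's actual SerializerManager code):

```scala
import scala.reflect.ClassTag

// Sketch: prefer Kryo when the tag names a type known to be safe for it,
// otherwise fall back to the default (Java) serializer. A block written by
// one serializer and read by the other fails at the first framing bytes,
// surfacing as a java.io.StreamCorruptedException.
object SerializerChoiceSketch {
  def chooseSerializer[T](implicit ct: ClassTag[T]): String =
    if (ct == ClassTag.Int || ct.runtimeClass == classOf[String]) "Kryo" else "Java"

  def main(args: Array[String]): Unit = {
    println(chooseSerializer[Int])    // Kryo
    println(chooseSerializer[String]) // Kryo
    println(chooseSerializer[Any])    // Java -- mismatched with a Kryo-written block
  }
}
```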
6 changes: 4 additions & 2 deletions core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -170,10 +170,12 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     blockManager.master.getLocations(blockId).foreach { cmId =>
       val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
         blockId.toString)
-      val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
-        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
+      val deserialized = serializerManager.dataDeserializeStream(blockId,
+        new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList
       assert(deserialized === (1 to 100).toList)
     }
+    // This will exercise the getRemoteBytes / getRemoteValues code paths:
+    assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet)
   }

   Seq(
@@ -120,7 +120,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
     val blockId = partition.blockId

     def getBlockFromBlockManager(): Option[Iterator[T]] = {
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
+      blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]])
     }

     def getBlockFromWriteAheadLog(): Iterator[T] = {
@@ -163,7 +163,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
         dataRead.rewind()
       }
       serializerManager
-        .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream())
+        .dataDeserializeStream(
+          blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
         .asInstanceOf[Iterator[T]]

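`elementClassTag` here is the RDD's own tag for `T`, captured once from the class's context bound and passed explicitly at the call site. A small sketch of that capture-once pattern (illustrative, not Spark's class hierarchy):

```scala
import scala.reflect.ClassTag

// Capture the tag at construction time, then thread the same value through
// every (de)serialization call, so no call site can re-infer T as Any.
abstract class TaggedCollection[T: ClassTag] {
  val elementClassTag: ClassTag[T] = implicitly[ClassTag[T]]
}
```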
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.reflect.ClassTag

 import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfter, Matchers}
@@ -163,7 +164,7 @@ class ReceivedBlockHandlerSuite
       val bytes = reader.read(fileSegment)
       reader.close()
       serializerManager.dataDeserializeStream(
-        generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList
+        generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
     }
     loggedData shouldEqual data
   }