
Commit 29cfab3

[SPARK-17110] Fix StreamCorruptedException in BlockManager.getRemoteValues()
## What changes were proposed in this pull request?

This patch fixes a `java.io.StreamCorruptedException` error affecting remote reads of cached values when certain data types are used. The problem stems from #11801 / SPARK-13990, a patch to have Spark automatically pick the "best" serializer when caching RDDs. If PySpark cached a PythonRDD, then this would be cached as an `RDD[Array[Byte]]` and the automatic serializer selection would pick KryoSerializer for replication and block transfer. However, the `getRemoteValues()` / `getRemoteBytes()` code path did not pass proper class tags in order to enable the same serializer to be used during deserialization, causing Java serialization to be inappropriately used instead of Kryo and leading to the StreamCorruptedException.

We already fixed a similar bug in #14311, which dealt with similar issues in block replication. Prior to that patch, it seems that we had no tests to ensure that block replication actually succeeded. Similarly, prior to this bug fix patch it looks like we had no tests to perform remote reads of cached data, which is why this bug was able to remain latent for so long.

This patch addresses the bug by modifying `BlockManager`'s `get()` and `getRemoteValues()` methods to accept ClassTags, allowing the proper class tag to be threaded in the `getOrElseUpdate` code path (which is used by `rdd.iterator`).

## How was this patch tested?

Extended the caching tests in `DistributedSuite` to exercise the `getRemoteValues` path, plus manual testing to verify that the PySpark bug reproduction in SPARK-17110 is fixed.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #14952 from JoshRosen/SPARK-17110.
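To make the failure mode concrete, here is a minimal, self-contained Scala sketch of the serializer-selection mismatch described above; `pickSerializer`, `KryoLike`, and `JavaLike` are illustrative stand-ins, not Spark's actual `SerializerManager` logic:

```scala
import scala.reflect.ClassTag

object SerializerChoiceSketch {
  // Illustrative stand-ins for the two serializers; not Spark's classes.
  sealed trait WhichSerializer
  case object KryoLike extends WhichSerializer
  case object JavaLike extends WhichSerializer

  // Sketch of "automatic serializer selection": byte-array elements (what a cached
  // PythonRDD produces) can safely use the Kryo-like serializer; anything the
  // selector cannot prove safe falls back to the Java-like one.
  def pickSerializer(ct: ClassTag[_]): WhichSerializer =
    if (ct.runtimeClass == classOf[Array[Byte]]) KryoLike else JavaLike

  def main(args: Array[String]): Unit = {
    val writeSide = pickSerializer(ClassTag(classOf[Array[Byte]])) // KryoLike
    // Bug: the remote-read path lost the element tag, so it picked a different
    // serializer than the write path and choked on the Kryo-encoded bytes.
    val buggyRead = pickSerializer(ClassTag.Any)                   // JavaLike
    // Fix: thread the same ClassTag through get()/getRemoteValues() to deserialization.
    val fixedRead = pickSerializer(ClassTag(classOf[Array[Byte]])) // KryoLike
    println(s"write=$writeSide buggyRead=$buggyRead fixedRead=$fixedRead")
  }
}
```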
1 parent 8bbb08a commit 29cfab3

6 files changed (+22, -16 lines)

core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
     assertValid()
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDPartition].blockId
-    blockManager.get(blockId) match {
+    blockManager.get[T](blockId) match {
       case Some(block) => block.data.asInstanceOf[Iterator[T]]
       case None =>
         throw new Exception("Could not compute split, block " + blockId + " not found")

core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala

Lines changed: 4 additions & 3 deletions
@@ -180,11 +180,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * Deserializes an InputStream into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserializeStream[T: ClassTag](
+  def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream): Iterator[T] = {
+      inputStream: InputStream)
+      (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    getSerializer(implicitly[ClassTag[T]])
+    getSerializer(classTag)
       .newInstance()
       .deserializeStream(wrapStream(blockId, stream))
       .asIterator.asInstanceOf[Iterator[T]]
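The signature change above moves the `ClassTag` from an implicit context bound to an explicit second parameter list, so a caller that holds a tag as a value can forward it instead of relying on implicit resolution at the call site. A simplified sketch of the call-site difference, with made-up names (`deserializeOld`, `deserializeNew`, `readBlock`) standing in for the real API:

```scala
import java.io.InputStream
import scala.reflect.ClassTag

object ClassTagThreadingSketch {
  // Before: the tag is resolved implicitly at each call site; a caller whose T is
  // not statically known gets whatever implicit happens to be in scope.
  def deserializeOld[T: ClassTag](blockId: String, in: InputStream): Iterator[T] =
    Iterator.empty

  // After: the tag is an explicit parameter that the caller must hand over.
  def deserializeNew[T](blockId: String, in: InputStream)(classTag: ClassTag[T]): Iterator[T] =
    Iterator.empty

  // A caller that received a tag as a value (e.g. an RDD's elementClassTag) can now
  // thread it straight through rather than depending on implicit resolution.
  def readBlock[T](blockId: String, in: InputStream, tag: ClassTag[T]): Iterator[T] =
    deserializeNew(blockId, in)(tag)
}
```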

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 8 additions & 7 deletions
@@ -520,10 +520,11 @@ private[spark] class BlockManager(
    *
    * This does not acquire a lock on this block in this JVM.
    */
-  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+  private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+    val ct = implicitly[ClassTag[T]]
     getRemoteBytes(blockId).map { data =>
       val values =
-        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+        serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
       new BlockResult(values, DataReadMethod.Network, data.size)
     }
   }
@@ -602,13 +603,13 @@
    * any locks if the block was fetched from a remote block manager. The read lock will
    * automatically be freed once the result's `data` iterator is fully consumed.
    */
-  def get(blockId: BlockId): Option[BlockResult] = {
+  def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
     val local = getLocalValues(blockId)
     if (local.isDefined) {
       logInfo(s"Found block $blockId locally")
       return local
     }
-    val remote = getRemoteValues(blockId)
+    val remote = getRemoteValues[T](blockId)
     if (remote.isDefined) {
       logInfo(s"Found block $blockId remotely")
       return remote
@@ -660,7 +661,7 @@
       makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
     // Attempt to read the block from local or remote storage. If it's present, then we don't need
     // to go through the local-get-or-put path.
-    get(blockId) match {
+    get[T](blockId)(classTag) match {
       case Some(block) =>
         return Left(block)
       case _ =>
@@ -1204,8 +1205,8 @@
   /**
    * Read a block consisting of a single object.
    */
-  def getSingle(blockId: BlockId): Option[Any] = {
-    get(blockId).map(_.data.next())
+  def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+    get[T](blockId).map(_.data.next().asInstanceOf[T])
   }
 
   /**
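Taken together, these signatures let the `getOrElseUpdate` path (used by `rdd.iterator`) hand the caller's ClassTag all the way down to remote-read deserialization. A rough sketch of that threading, using simplified placeholder types (`BlockResultLike`, `BlockStoreSketch`) rather than the real `BlockManager` internals:

```scala
import scala.reflect.ClassTag

// Simplified placeholder; the real BlockResult carries read-method and size metadata.
final case class BlockResultLike(data: Iterator[Any])

trait BlockStoreSketch {
  // Corresponds to the new get[T: ClassTag] above.
  def get[T: ClassTag](blockId: String): Option[BlockResultLike]

  // getOrElseUpdate already receives an explicit ClassTag; the fix is to pass it on
  // explicitly rather than letting the compiler infer one that may differ from the
  // tag used when the block was written.
  def getOrElseUpdate[T](
      blockId: String,
      classTag: ClassTag[T],
      makeIterator: () => Iterator[T]): Either[BlockResultLike, Iterator[T]] = {
    get[T](blockId)(classTag) match {
      case Some(block) => Left(block)
      case None => Right(makeIterator())
    }
  }
}
```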

core/src/test/scala/org/apache/spark/DistributedSuite.scala

Lines changed: 4 additions & 2 deletions
@@ -170,10 +170,12 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
       blockManager.master.getLocations(blockId).foreach { cmId =>
         val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
           blockId.toString)
-        val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
-          new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
+        val deserialized = serializerManager.dataDeserializeStream(blockId,
+          new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList
         assert(deserialized === (1 to 100).toList)
       }
+      // This will exercise the getRemoteBytes / getRemoteValues code paths:
+      assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet)
     }
 
     Seq(

streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala

Lines changed: 3 additions & 2 deletions
@@ -120,7 +120,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
     val blockId = partition.blockId
 
     def getBlockFromBlockManager(): Option[Iterator[T]] = {
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
+      blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]])
     }
 
     def getBlockFromWriteAheadLog(): Iterator[T] = {
@@ -163,7 +163,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
           dataRead.rewind()
         }
         serializerManager
-          .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream())
+          .dataDeserializeStream(
+            blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
           .asInstanceOf[Iterator[T]]
       }

streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfter, Matchers}
@@ -163,7 +164,7 @@ class ReceivedBlockHandlerSuite
       val bytes = reader.read(fileSegment)
       reader.close()
       serializerManager.dataDeserializeStream(
-        generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList
+        generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
     }
     loggedData shouldEqual data
   }
