
Commit 8e223ea

ericl authored and rxin committed
[SPARK-16550][SPARK-17042][CORE] Certain classes fail to deserialize in block manager replication
## What changes were proposed in this pull request? This is a straightforward clone of JoshRosen 's original patch. I have follow-up changes to fix block replication for repl-defined classes as well, but those appear to be flaking tests so I'm going to leave that for SPARK-17042 ## How was this patch tested? End-to-end test in ReplSuite (also more tests in DistributedSuite from the original patch). Author: Eric Liang <ekl@databricks.com> Closes #14311 from ericl/spark-16550.
1 parent 71afeee commit 8e223ea
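For context, the failure this patch guards against shows up when a class defined only in the repl is persisted at a replicated storage level. The snippet below is a repl-style sketch mirroring the ReplSuite test added in this commit; it is meant for a spark-shell session backed by two executors, and the exact setup is illustrative rather than part of the patch.

```scala
// Repl-style sketch of the SPARK-16550 scenario (mirrors the new ReplSuite test):
// persist an RDD of a repl-defined case class with replication 2 and check
// how many RDD blocks actually ended up on the executors.
import org.apache.spark.storage.StorageLevel._

case class Foo(i: Int)  // class only exists in the repl's class server

val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_ONLY_2)
ret.count()  // materializes the blocks and triggers replication

// 10 partitions x replication 2 => 20 RDD blocks expected across the executors.
sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
```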

File tree

4 files changed (+60, -58 lines)

core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala

Lines changed: 12 additions & 2 deletions

@@ -68,7 +68,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
    * loaded yet. */
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
 
-  private def canUseKryo(ct: ClassTag[_]): Boolean = {
+  def canUseKryo(ct: ClassTag[_]): Boolean = {
     primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
   }
 
@@ -128,8 +128,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
 
   /** Serializes into a chunked byte buffer. */
   def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
+    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
+  }
+
+  /** Serializes into a chunked byte buffer. */
+  def dataSerializeWithExplicitClassTag(
+      blockId: BlockId,
+      values: Iterator[_],
+      classTag: ClassTag[_]): ChunkedByteBuffer = {
     val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
-    dataSerializeStream(blockId, bbos, values)
+    val byteStream = new BufferedOutputStream(bbos)
+    val ser = getSerializer(classTag).newInstance()
+    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
     bbos.toChunkedByteBuffer
   }
 
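The new overload exists so that callers holding only a runtime ClassTag[_] (for example, the tag recorded with a block's metadata) can still route serialization through tag-aware serializer selection. Below is a minimal sketch of how the two entry points relate; the wrapper object, package, and method names are illustrative, not part of the patch, and the package placement only reflects that these classes are private[spark].

```scala
package org.apache.spark.example  // sketch only: SerializerManager is private[spark]

import scala.reflect.ClassTag

import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.BlockId
import org.apache.spark.util.io.ChunkedByteBuffer

object DataSerializeSketch {
  // When the element type T is known statically, dataSerialize picks up the
  // implicit ClassTag and simply forwards it to the explicit-tag overload.
  def serializeTyped[T: ClassTag](
      sm: SerializerManager, blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer =
    sm.dataSerialize(blockId, values)

  // When only a runtime ClassTag[_] is available (e.g. the tag stored alongside
  // the block's metadata), the new overload is called directly.
  def serializeWithRuntimeTag(
      sm: SerializerManager, blockId: BlockId, values: Iterator[_],
      tag: ClassTag[_]): ChunkedByteBuffer =
    sm.dataSerializeWithExplicitClassTag(blockId, values, tag)
}
```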

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 11 additions & 2 deletions

@@ -498,7 +498,8 @@ private[spark] class BlockManager(
       diskStore.getBytes(blockId)
     } else if (level.useMemory && memoryStore.contains(blockId)) {
       // The block was not found on disk, so serialize an in-memory copy:
-      serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get)
+      serializerManager.dataSerializeWithExplicitClassTag(
+        blockId, memoryStore.getValues(blockId).get, info.classTag)
     } else {
       handleLocalReadFailure(blockId)
     }

@@ -973,8 +974,16 @@ private[spark] class BlockManager(
     if (level.replication > 1) {
       val remoteStartTime = System.currentTimeMillis
       val bytesToReplicate = doGetLocalBytes(blockId, info)
+      // [SPARK-16550] Erase the typed classTag when using default serialization, since
+      // NettyBlockRpcServer crashes when deserializing repl-defined classes.
+      // TODO(ekl) remove this once the classloader issue on the remote end is fixed.
+      val remoteClassTag = if (!serializerManager.canUseKryo(classTag)) {
+        scala.reflect.classTag[Any]
+      } else {
+        classTag
+      }
       try {
-        replicate(blockId, bytesToReplicate, level, classTag)
+        replicate(blockId, bytesToReplicate, level, remoteClassTag)
       } finally {
         bytesToReplicate.dispose()
       }
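The key decision in the replication path above is whether the precise class tag can safely travel with the block: tags accepted by canUseKryo (primitives, primitive arrays, String) are kept, while anything else is erased to ClassTag[Any] so the remote NettyBlockRpcServer never has to resolve a repl-defined class. The following is a self-contained sketch of that decision, with canUseKryo deliberately stubbed out and simplified, since the real check lives in SerializerManager.

```scala
import scala.reflect.{classTag, ClassTag}

object ReplicationTagSketch {
  // Simplified stand-in for SerializerManager.canUseKryo: the real check covers
  // every primitive and primitive-array class tag plus String.
  def canUseKryo(ct: ClassTag[_]): Boolean =
    ct == classTag[Int] || ct == classTag[Array[Int]] || ct == classTag[String]

  // Same shape as the BlockManager change: keep the precise tag when Kryo
  // handles the type, otherwise erase to ClassTag[Any] for the remote end.
  def tagForReplication(ct: ClassTag[_]): ClassTag[_] =
    if (canUseKryo(ct)) ct else classTag[Any]

  def main(args: Array[String]): Unit = {
    assert(tagForReplication(classTag[Int]) == classTag[Int])
    assert(tagForReplication(classTag[Map[String, Int]]) == classTag[Any])
    println("class tag erasure behaves as expected")
  }
}
```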

core/src/test/scala/org/apache/spark/DistributedSuite.scala

Lines changed: 23 additions & 54 deletions

@@ -149,61 +149,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     sc.parallelize(1 to 10).count()
   }
 
-  test("caching") {
+  private def testCaching(storageLevel: StorageLevel): Unit = {
     sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).cache()
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-  }
-
-  test("caching on disk") {
-    sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-  }
-
-  test("caching in memory, replicated") {
-    sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-  }
-
-  test("caching in memory, serialized, replicated") {
-    sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-  }
-
-  test("caching on disk, replicated") {
-    sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-  }
-
-  test("caching in memory and disk, replicated") {
-    sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-  }
-
-  test("caching in memory and disk, serialized, replicated") {
-    sc = new SparkContext(clusterUrl, "test")
-    val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
-
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
-    assert(data.count() === 1000)
+    sc.jobProgressListener.waitUntilExecutorsUp(2, 30000)
+    val data = sc.parallelize(1 to 1000, 10)
+    val cachedData = data.persist(storageLevel)
+    assert(cachedData.count === 1000)
+    assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum ===
+      storageLevel.replication * data.getNumPartitions)
+    assert(cachedData.count === 1000)
+    assert(cachedData.count === 1000)
 
     // Get all the locations of the first partition and try to fetch the partitions
     // from those locations.

@@ -221,6 +176,20 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     }
   }
 
+  Seq(
+    "caching" -> StorageLevel.MEMORY_ONLY,
+    "caching on disk" -> StorageLevel.DISK_ONLY,
+    "caching in memory, replicated" -> StorageLevel.MEMORY_ONLY_2,
+    "caching in memory, serialized, replicated" -> StorageLevel.MEMORY_ONLY_SER_2,
+    "caching on disk, replicated" -> StorageLevel.DISK_ONLY_2,
+    "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2,
+    "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2
+  ).foreach { case (testName, storageLevel) =>
+    test(testName) {
+      testCaching(storageLevel)
+    }
+  }
+
   test("compute without caching when no partitions fit in memory") {
     val size = 10000
     val conf = new SparkConf()
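The stronger assertion in testCaching pins down how many RDD blocks should exist after the count: one per cached partition, multiplied by the storage level's replication factor. A small sketch of that arithmetic using the suite's 10 partitions; the helper object and method here are illustrative, not part of the patch.

```scala
import org.apache.spark.storage.StorageLevel

object ExpectedBlockCounts {
  // One block per cached partition, multiplied by the level's replication factor.
  def expectedBlocks(level: StorageLevel, numPartitions: Int): Int =
    level.replication * numPartitions

  def main(args: Array[String]): Unit = {
    // With the suite's 10 partitions: non-replicated levels expect 10 blocks,
    // the *_2 levels expect 20.
    assert(expectedBlocks(StorageLevel.MEMORY_ONLY, 10) == 10)
    assert(expectedBlocks(StorageLevel.MEMORY_AND_DISK_SER_2, 10) == 20)
  }
}
```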

repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 14 additions & 0 deletions

@@ -396,6 +396,20 @@ class ReplSuite extends SparkFunSuite {
     assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
   }
 
+  test("replicating blocks of object with class defined in repl") {
+    val output = runInterpreter("local-cluster[2,1,1024]",
+      """
+        |import org.apache.spark.storage.StorageLevel._
+        |case class Foo(i: Int)
+        |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_ONLY_2)
+        |ret.count()
+        |sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains(": Int = 20", output)
+  }
+
   test("line wrapper only initialized once when used as encoder outer scope") {
     val output = runInterpreter("local",
       """

0 commit comments