[SPARK-13990] Automatically pick serializer when caching RDDs #11801

Closed
core/src/main/scala/org/apache/spark/SparkEnv.scala (2 changes: 1 addition & 1 deletion)

@@ -331,7 +331,7 @@ object SparkEnv extends Logging {

     // NB: blockManager is not valid until initialize() is called later.
     val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
-      serializer, conf, memoryManager, mapOutputTracker, shuffleManager,
+      serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
       blockTransferService, securityManager, numUsableCores)

     val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
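For context on the swap above from a single serializer to a SerializerManager: the goal of this PR is to let the block manager pick a serializer from the block's ClassTag, using Kryo automatically when the element type is known to be safe for it and falling back to the default serializer otherwise. A minimal, self-contained sketch of that selection idea follows; SimpleSerializerManager and everything in it are hypothetical stand-ins, not Spark's actual SerializerManager API.

import scala.reflect.ClassTag

// Hypothetical stand-ins for Spark's serializer implementations.
trait Serializer
class JavaSerializer extends Serializer
class KryoSerializer extends Serializer

// Sketch: pick Kryo for element types that need no custom registration
// (primitives and strings), otherwise fall back to the default serializer.
class SimpleSerializerManager(default: Serializer, kryo: KryoSerializer) {
  private val kryoFriendly: Set[ClassTag[_]] = Set(
    ClassTag.Boolean, ClassTag.Byte, ClassTag.Char, ClassTag.Short,
    ClassTag.Int, ClassTag.Long, ClassTag.Float, ClassTag.Double,
    ClassTag(classOf[String]))

  def getSerializer(ct: ClassTag[_]): Serializer =
    if (kryoFriendly.contains(ct)) kryo else default
}

object SerializerSelectionDemo extends App {
  val manager = new SimpleSerializerManager(new JavaSerializer, new KryoSerializer)
  println(manager.getSerializer(ClassTag.Int).getClass.getSimpleName)              // KryoSerializer
  println(manager.getSerializer(ClassTag(classOf[AnyRef])).getClass.getSimpleName) // JavaSerializer
}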
core/src/main/scala/org/apache/spark/network/BlockDataManager.scala

@@ -17,6 +17,8 @@

 package org.apache.spark.network

+import scala.reflect.ClassTag
+
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.storage.{BlockId, StorageLevel}

@@ -35,7 +37,11 @@ trait BlockDataManager {
    * Returns true if the block was stored and false if the put operation failed or the block
    * already existed.
    */
-  def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean
+  def putBlockData(
+      blockId: BlockId,
+      data: ManagedBuffer,
+      level: StorageLevel,
+      classTag: ClassTag[_]): Boolean

   /**
    * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
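Implementations and callers of the widened putBlockData signature now need a ClassTag value to pass along. As a quick standalone reminder (nothing Spark-specific here), a ClassTag can be summoned from a static type, built from a runtime Class, or defaulted to AnyRef when the element type is unknown:

import scala.reflect.{classTag, ClassTag}

object ClassTagExamples extends App {
  // From a statically known type.
  val stringTag: ClassTag[String] = classTag[String]

  // From a runtime Class object.
  val listTag: ClassTag[_] = ClassTag(classOf[java.util.ArrayList[_]])

  // When the element type is unknown (e.g. raw bytes received over the wire),
  // a conservative fallback is the AnyRef tag.
  val fallback: ClassTag[_] = ClassTag.AnyRef

  println(stringTag.runtimeClass) // class java.lang.String
  println(listTag.runtimeClass)   // class java.util.ArrayList
  println(fallback.runtimeClass)  // class java.lang.Object
}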
core/src/main/scala/org/apache/spark/network/BlockTransferService.scala

@@ -22,6 +22,7 @@ import java.nio.ByteBuffer

 import scala.concurrent.{Await, Future, Promise}
 import scala.concurrent.duration.Duration
+import scala.reflect.ClassTag

 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}

@@ -76,7 +77,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
       execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
-      level: StorageLevel): Future[Unit]
+      level: StorageLevel,
+      classTag: ClassTag[_]): Future[Unit]

   /**
    * A special case of [[fetchBlocks]], as it fetches only one block and is blocking.

@@ -114,7 +116,9 @@
       execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
-      level: StorageLevel): Unit = {
-    Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
+      level: StorageLevel,
+      classTag: ClassTag[_]): Unit = {
+    val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag)
+    Await.result(future, Duration.Inf)
   }
 }
core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

@@ -20,6 +20,8 @@ package org.apache.spark.network.netty

 import java.nio.ByteBuffer

 import scala.collection.JavaConverters._
+import scala.language.existentials
+import scala.reflect.ClassTag

 import org.apache.spark.internal.Logging
 import org.apache.spark.network.BlockDataManager

@@ -61,12 +63,16 @@ class NettyBlockRpcServer(
         responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

       case uploadBlock: UploadBlock =>
-        // StorageLevel is serialized as bytes using our JavaSerializer.
-        val level: StorageLevel =
-          serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
+        val (level: StorageLevel, classTag: ClassTag[_]) = {
+          serializer
+            .newInstance()
+            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+            .asInstanceOf[(StorageLevel, ClassTag[_])]
+        }
         val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
         val blockId = BlockId(uploadBlock.blockId)
-        blockManager.putBlockData(blockId, data, level)
+        blockManager.putBlockData(blockId, data, level, classTag)
         responseContext.onSuccess(ByteBuffer.allocate(0))
     }
   }
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

@@ -21,6 +21,7 @@ import java.nio.ByteBuffer

 import scala.collection.JavaConverters._
 import scala.concurrent.{Future, Promise}
+import scala.reflect.ClassTag

 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._

@@ -118,18 +119,19 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager
       execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
-      level: StorageLevel): Future[Unit] = {
+      level: StorageLevel,
+      classTag: ClassTag[_]): Future[Unit] = {
     val result = Promise[Unit]()
     val client = clientFactory.createClient(hostname, port)

-    // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
-    // using our binary protocol.
-    val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))
+    // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
+    // Everything else is encoded using our binary protocol.
+    val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))

     // Convert or copy nio buffer into array in order to serialize it.
     val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())

-    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer,
       new RpcResponseCallback {
         override def onSuccess(response: ByteBuffer): Unit = {
           logTrace(s"Successfully uploaded block $blockId")
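The two Netty-side changes are mirror images: NettyBlockTransferService packs (level, classTag) into the UploadBlock metadata, and NettyBlockRpcServer unpacks it before calling putBlockData. The sketch below reproduces that roundtrip with plain JDK serialization (which Spark's JavaSerializer wraps); a String stands in for StorageLevel so the example stays self-contained.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.reflect.ClassTag

object MetadataRoundTrip extends App {
  // "Client" side: serialize the (level, classTag) pair into the metadata bytes
  // that ride along with the block payload.
  val level = "MEMORY_AND_DISK" // stand-in for a StorageLevel
  val classTag: ClassTag[_] = ClassTag(classOf[Array[Byte]])

  val bytesOut = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytesOut)
  out.writeObject((level, classTag))
  out.close()
  val metadata: Array[Byte] = bytesOut.toByteArray

  // "Server" side: deserialize and cast the pair back out, mirroring the
  // asInstanceOf[(StorageLevel, ClassTag[_])] in NettyBlockRpcServer.
  val in = new ObjectInputStream(new ByteArrayInputStream(metadata))
  val (recoveredLevel, recoveredTag) = in.readObject().asInstanceOf[(String, ClassTag[_])]
  in.close()

  println(recoveredLevel)            // MEMORY_AND_DISK
  println(recoveredTag.runtimeClass) // class [B, i.e. Array[Byte]
}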
core/src/main/scala/org/apache/spark/rdd/RDD.scala (2 changes: 1 addition & 1 deletion)

@@ -326,7 +326,7 @@ abstract class RDD[T: ClassTag](
     val blockId = RDDBlockId(id, partition.index)
     var readCachedBlock = true
     // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
-    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
       readCachedBlock = false
       computeOrReadCheckpoint(partition, context)
     }) match {
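From the user's side nothing changes here: persisting an RDD works as before, but the RDD's element ClassTag now travels with each cached partition through getOrElseUpdate(blockId, storageLevel, elementClassTag, ...). A small sketch of that path, assuming a local Spark build on the classpath:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

object CachedRddExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("cached-rdd-example").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      // The element type (Int) is captured in the RDD's ClassTag, which the
      // block manager can now use to pick a serializer for the cached blocks.
      val nums = sc.parallelize(1 to 1000)
      nums.persist(StorageLevel.MEMORY_ONLY_SER)
      println(nums.map(_ * 2).sum()) // triggers computation and caching of `nums`
    } finally {
      sc.stop()
    }
  }
}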
core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

@@ -21,6 +21,7 @@ import javax.annotation.concurrent.GuardedBy

 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag

 import com.google.common.collect.ConcurrentHashMultiset

@@ -37,10 +38,14 @@ import org.apache.spark.internal.Logging
  * @param level the block's storage level. This is the requested persistence level, not the
  *              effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this
  *              does not imply that the block is actually resident in memory).
+ * @param classTag the block's [[ClassTag]], used to select the serializer
  * @param tellMaster whether state changes for this block should be reported to the master. This
  *                   is true for most blocks, but is false for broadcast blocks.
  */
-private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+private[storage] class BlockInfo(
+    val level: StorageLevel,
+    val classTag: ClassTag[_],
+    val tellMaster: Boolean) {

   /**
    * The size of the block (in bytes)