Skip to content

[SPARK-2897][SPARK-2920]TorrentBroadcast does use the serializer class specified in the spark option "spark.serializer" #1836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

package org.apache.spark.broadcast

import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream,
ObjectInputStream, ObjectOutputStream, OutputStream}

import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils

/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
Expand Down Expand Up @@ -214,11 +215,15 @@ private[broadcast] object TorrentBroadcast extends Logging {
private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
private var compress: Boolean = false
private var compressionCodec: CompressionCodec = null

def initialize(_isDriver: Boolean, conf: SparkConf) {
TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
if (!initialized) {
compress = conf.getBoolean("spark.broadcast.compress", true)
compressionCodec = CompressionCodec.createCodec(conf)
initialized = true
}
}
Expand All @@ -228,8 +233,13 @@ private[broadcast] object TorrentBroadcast extends Logging {
initialized = false
}

def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
val bos = new ByteArrayOutputStream()
val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
val ser = SparkEnv.get.serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
val byteArray = bos.toByteArray
val bais = new ByteArrayInputStream(byteArray)

var blockNum = byteArray.length / BLOCK_SIZE
Expand All @@ -255,7 +265,7 @@ private[broadcast] object TorrentBroadcast extends Logging {
info
}

def unBlockifyObject[T](
def unBlockifyObject[T: ClassTag](
arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
Expand All @@ -264,7 +274,16 @@ private[broadcast] object TorrentBroadcast extends Logging {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
}
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)

val in: InputStream = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rewrite this as

val in: InputStream = {
  val istream = new ByteArrayInputStream(retByteArray)
  if (compress) compressionCodec.compressedInputStream(istream) else istream
}

val arrIn = new ByteArrayInputStream(retByteArray)
if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
}
val ser = SparkEnv.get.serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
obj
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {

test("Accessing HttpBroadcast variables in a local cluster") {
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
val conf = httpConf.clone
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
Expand All @@ -69,7 +72,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {

test("Accessing TorrentBroadcast variables in a local cluster") {
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
val conf = torrentConf.clone
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
Expand Down