
Commit 2d6a5fb

Use putBytes/getRemoteBytes throughout.
1 parent 3670f00 commit 2d6a5fb

File tree

1 file changed: +12 -12 lines changed


core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 12 additions & 12 deletions
@@ -27,6 +27,7 @@ import scala.util.Random
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream

 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
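The new import is what later lets unBlockifyObject read the fetched ByteBuffer pieces as a stream. Roughly, org.apache.spark.util.ByteBufferInputStream is an InputStream view over a ByteBuffer; the sketch below is only an illustrative stand-in (SimpleByteBufferInputStream is a made-up name, and Spark's real class does more, e.g. buffer disposal):

import java.io.InputStream
import java.nio.ByteBuffer

// Illustrative only: a minimal InputStream view over a ByteBuffer, roughly the
// idea behind org.apache.spark.util.ByteBufferInputStream.
class SimpleByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
  override def read(): Int =
    if (!buffer.hasRemaining) -1
    else buffer.get() & 0xFF  // mask so the byte comes back as 0..255

  override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
    if (!buffer.hasRemaining) {
      -1
    } else {
      val amountToGet = math.min(buffer.remaining(), length)
      buffer.get(dest, offset, amountToGet)
      amountToGet
    }
  }
}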
@@ -76,8 +77,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def writeBlocks(): Int = {
     val blocks = TorrentBroadcast.blockifyObject(_value)
     blocks.zipWithIndex.foreach { case (block, i) =>
-      // TODO: Use putBytes directly.
-      SparkEnv.get.blockManager.putSingle(
+      SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),
         block,
         StorageLevel.MEMORY_AND_DISK_SER,
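This resolves the old TODO: blockifyObject (changed further down) now produces already-serialized ByteBuffer pieces, so each piece can be stored with putBytes instead of going back through putSingle's object path, which would serialize the piece a second time inside the BlockManager. A minimal sketch of the write call, assuming the arguments past StorageLevel (not visible in this hunk) mirror the tellMaster = true used in readBlocks; storePiece is a hypothetical helper, not code from the patch:

import java.nio.ByteBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.storage.{BlockId, StorageLevel}

// Hypothetical helper: store one already-serialized broadcast piece as-is.
// tellMaster = true is an assumption carried over from the readBlocks call below.
def storePiece(pieceId: BlockId, piece: ByteBuffer): Unit = {
  SparkEnv.get.blockManager.putBytes(
    pieceId,
    piece,
    StorageLevel.MEMORY_AND_DISK_SER,
    tellMaster = true)
}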
@@ -87,21 +87,21 @@ private[spark] class TorrentBroadcast[T: ClassTag](
8787
}
8888

8989
/** Fetch torrent blocks from the driver and/or other executors. */
90-
private def readBlocks(): Array[Array[Byte]] = {
90+
private def readBlocks(): Array[ByteBuffer] = {
9191
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
9292
// to the driver, so other executors can pull these chunks from this executor as well.
9393
var numBlocksAvailable = 0
94-
val blocks = new Array[Array[Byte]](numBlocks)
94+
val blocks = new Array[ByteBuffer](numBlocks)
9595

9696
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
9797
val pieceId = BroadcastBlockId(id, "piece" + pid)
98-
SparkEnv.get.blockManager.getSingle(pieceId) match {
98+
SparkEnv.get.blockManager.getRemoteBytes(pieceId) match {
9999
case Some(x) =>
100-
blocks(pid) = x.asInstanceOf[Array[Byte]]
100+
blocks(pid) = x.asInstanceOf[ByteBuffer]
101101
numBlocksAvailable += 1
102102
SparkEnv.get.blockManager.putBytes(
103103
pieceId,
104-
ByteBuffer.wrap(blocks(pid)),
104+
blocks(pid),
105105
StorageLevel.MEMORY_AND_DISK_SER,
106106
tellMaster = true)
107107
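The readBlocks change keeps the BitTorrent-like behavior: pieces are fetched in a random order to spread load across sources, and every fetched piece is immediately re-stored locally with tellMaster = true so other executors can pull it from this executor as well. A simplified sketch of that loop (fetchPieces is a stand-in for the real method; the availability counting and missing-piece handling are omitted here):

import java.nio.ByteBuffer
import scala.util.Random

import org.apache.spark.SparkEnv
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}

// Illustrative sketch of the fetch-and-reshare pattern; `id` and `numBlocks`
// stand in for the broadcast instance's fields.
def fetchPieces(id: Long, numBlocks: Int): Array[ByteBuffer] = {
  val blocks = new Array[ByteBuffer](numBlocks)
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    SparkEnv.get.blockManager.getRemoteBytes(pieceId).foreach { bytes =>
      blocks(pid) = bytes
      // Re-store the piece locally and report it, so this executor becomes a source.
      SparkEnv.get.blockManager.putBytes(
        pieceId, bytes, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
    }
  }
  blocks
}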

@@ -182,7 +182,7 @@ private object TorrentBroadcast extends Logging {
     initialized = false
   }

-  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+  def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
     // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
     // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
@@ -193,24 +193,24 @@
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
     val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)

     var blockId = 0
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)

-      blocks(blockId) = tempByteArray
+      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
       blockId += 1
     }
     bais.close()
     blocks
   }

-  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+  def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
     val is = new SequenceInputStream(
-      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+      asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
     val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is

     val ser = SparkEnv.get.serializer.newInstance()
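Outside of Spark, the blockify/unBlockify round trip is just chunking a serialized byte array into ByteBuffers and chaining them back into one InputStream. A self-contained sketch (BLOCK_SIZE here is an arbitrary example value, not Spark's constant, and a plain ByteArrayInputStream stands in for Spark's ByteBufferInputStream since the wrapped buffers are heap-backed):

import java.io.{ByteArrayInputStream, SequenceInputStream}
import java.nio.ByteBuffer
import scala.collection.JavaConverters._

object BlockifyRoundTrip extends App {
  val BLOCK_SIZE = 4  // example size only; Spark's real block size is much larger
  val payload: Array[Byte] = "hello torrent broadcast".getBytes("UTF-8")

  // Chunk: one ByteBuffer per BLOCK_SIZE-sized slice (the last one may be shorter).
  val blocks: Array[ByteBuffer] =
    payload.grouped(BLOCK_SIZE).map(chunk => ByteBuffer.wrap(chunk)).toArray

  // Reassemble: view each heap-backed ByteBuffer as an InputStream and chain them.
  val streams = blocks.iterator.map { buf =>
    new ByteArrayInputStream(buf.array(), buf.arrayOffset() + buf.position(), buf.remaining())
  }
  val in = new SequenceInputStream(streams.asJavaEnumeration)

  val roundTripped = Iterator.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray
  assert(roundTripped.sameElements(payload))
  println(s"reassembled ${roundTripped.length} bytes OK")
}

Running this prints the reassembled byte count, and the assert confirms the chunk/reassemble round trip is lossless.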
