@@ -27,6 +27,7 @@ import scala.util.Random
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream

 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -76,8 +77,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def writeBlocks(): Int = {
     val blocks = TorrentBroadcast.blockifyObject(_value)
     blocks.zipWithIndex.foreach { case (block, i) =>
-      // TODO: Use putBytes directly.
-      SparkEnv.get.blockManager.putSingle(
+      SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),
         block,
         StorageLevel.MEMORY_AND_DISK_SER,
@@ -87,21 +87,21 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   }

   /** Fetch torrent blocks from the driver and/or other executors. */
-  private def readBlocks(): Array[Array[Byte]] = {
+  private def readBlocks(): Array[ByteBuffer] = {
     // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
     // to the driver, so other executors can pull these chunks from this executor as well.
     var numBlocksAvailable = 0
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)

     for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
       val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
+      SparkEnv.get.blockManager.getRemoteBytes(pieceId) match {
         case Some(x) =>
-          blocks(pid) = x.asInstanceOf[Array[Byte]]
+          blocks(pid) = x.asInstanceOf[ByteBuffer]
           numBlocksAvailable += 1
           SparkEnv.get.blockManager.putBytes(
             pieceId,
-            ByteBuffer.wrap(blocks(pid)),
+            blocks(pid),
             StorageLevel.MEMORY_AND_DISK_SER,
             tellMaster = true)
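The readBlocks hunk above is the BitTorrent-like core of this class: pieces are fetched in random order, and each fetched piece is immediately re-stored with tellMaster = true so later readers can pull it from this executor instead of the driver. A minimal standalone sketch of that pattern, with `remote` and `local` as hypothetical stand-ins for the BlockManager (not Spark API):

    import scala.collection.mutable
    import scala.util.Random

    object TorrentFetchSketch {
      // Returns true once every piece is held locally.
      def fetchAll(numBlocks: Int,
                   remote: Int => Option[Array[Byte]],
                   local: mutable.Map[Int, Array[Byte]]): Boolean = {
        var numBlocksAvailable = 0
        // Randomized piece order spreads load: concurrent readers are
        // unlikely to ask the same source for the same piece at once.
        for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
          remote(pid).foreach { bytes =>
            local(pid) = bytes // re-publish the piece, BitTorrent style
            numBlocksAvailable += 1
          }
        }
        numBlocksAvailable == numBlocks
      }
    }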
@@ -182,7 +182,7 @@ private object TorrentBroadcast extends Logging {
     initialized = false
   }

-  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+  def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
     // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
     // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
@@ -193,24 +193,24 @@ private object TorrentBroadcast extends Logging {
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
     val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)

     var blockId = 0
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)

-      blocks(blockId) = tempByteArray
+      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
       blockId += 1
     }
     bais.close()
     blocks
   }

-  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+  def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
     val is = new SequenceInputStream(
-      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+      asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
     val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is

     val ser = SparkEnv.get.serializer.newInstance()
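After this change, blockifyObject and unBlockifyObject round-trip through ByteBuffers: the serialized object is sliced into BLOCK_SIZE chunks wrapped with ByteBuffer.wrap, then reassembled by chaining one stream per buffer through a SequenceInputStream. A self-contained sketch of that round trip, using a simplified stand-in for org.apache.spark.util.ByteBufferInputStream (names and the tiny block size here are illustrative, not the Spark classes):

    import java.io.{InputStream, SequenceInputStream}
    import java.nio.ByteBuffer
    import scala.collection.JavaConverters._

    object BlockifyRoundTrip {
      val BLOCK_SIZE = 4 // tiny, for illustration; the real block size is much larger

      // Mirrors blockifyObject: slice the bytes and wrap each slice in a buffer.
      def blockify(bytes: Array[Byte]): Array[ByteBuffer] =
        bytes.grouped(BLOCK_SIZE).map(chunk => ByteBuffer.wrap(chunk)).toArray

      // Simplified stand-in for ByteBufferInputStream: an InputStream view
      // over a ByteBuffer that consumes the buffer's position as it reads.
      class BufferInputStream(buf: ByteBuffer) extends InputStream {
        override def read(): Int = if (buf.hasRemaining) buf.get() & 0xFF else -1
      }

      // Mirrors unBlockifyObject: chain one stream per buffer and read through.
      def unblockify(blocks: Array[ByteBuffer]): Array[Byte] = {
        val is = new SequenceInputStream(
          blocks.iterator.map(b => new BufferInputStream(b): InputStream).asJavaEnumeration)
        Iterator.continually(is.read()).takeWhile(_ != -1).map(_.toByte).toArray
      }

      def main(args: Array[String]): Unit = {
        val data = "torrent broadcast round trip".getBytes("UTF-8")
        val blocks = blockify(data)
        assert(java.util.Arrays.equals(data, unblockify(blocks)))
        println(s"round-tripped ${data.length} bytes through ${blocks.length} buffers")
      }
    }

Wrapping each chunk instead of copying it is the point of the change: the blocks can flow from blockifyObject into BlockManager.putBytes, and back out of getRemoteBytes into unBlockifyObject, without converting between Array[Byte] and ByteBuffer at each hop.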