 package org.apache.spark.broadcast
 
 import java.io._
+import java.nio.ByteBuffer
 
+import scala.collection.JavaConversions.asJavaEnumeration
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -27,41 +29,87 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 
 /**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- * protocol to do a distributed transfer of the broadcasted data to the executors.
- * The mechanism is as follows. The driver divides the serializes the broadcasted data,
- * divides it into smaller chunks, and stores them in the BlockManager of the driver.
- * These chunks are reported to the BlockManagerMaster so that all the executors can
- * learn the location of those chunks. The first time the broadcast variable (sent as
- * part of task) is deserialized at a executor, all the chunks are fetched using
- * the BlockManager. When all the chunks are fetched (initially from the driver's
- * BlockManager), they are combined and deserialized to recreate the broadcasted data.
- * However, the chunks are also stored in the BlockManager and reported to the
- * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- * made to other executors who already have those chunks, resulting in a distributed
- * fetching. This prevents the driver from being the bottleneck in sending out multiple
- * copies of the broadcast data (one per executor) as done by the
- * [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its BlockManager. If
+ * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
+ * other executors if available. Once it gets the chunks, it puts the chunks in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
+ * @param obj object to broadcast
+ * @param isLocal whether Spark is running in local mode (single JVM process).
+ * @param id A unique identifier for the broadcast variable.
  */
 private[spark] class TorrentBroadcast[T: ClassTag](
-    @transient var value_ : T, isLocal: Boolean, id: Long)
+    obj: T,
+    @transient private val isLocal: Boolean,
+    id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
-  override protected def getValue() = value_
+  override protected def getValue() = _value
+
+  /**
+   * Value of the broadcast object. On driver, this is set directly by the constructor.
+   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
+   * blocks from the driver and/or other executors.
+   */
+  @transient private var _value: T = obj
+
+  /** Total number of blocks this broadcast variable contains. */
+  private val numBlocks: Int = writeBlocks()
 
   private val broadcastId = BroadcastBlockId(id)
 
-  SparkEnv.get.blockManager.putSingle(
-    broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+  /**
+   * Divide the object into multiple blocks and put those blocks in the block manager.
+   *
+   * @return number of blocks this broadcast variable is divided into
+   */
+  private def writeBlocks(): Int = {
+    val blocks = TorrentBroadcast.blockifyObject(_value)
+    blocks.zipWithIndex.foreach { case (block, i) =>
+      // TODO: Use putBytes directly.
+      SparkEnv.get.blockManager.putSingle(
+        BroadcastBlockId(id, "piece" + i),
+        blocks(i),
+        StorageLevel.MEMORY_AND_DISK_SER,
+        tellMaster = true)
+    }
+    blocks.length
+  }
 
-  @transient private var arrayOfBlocks: Array[TorrentBlock] = null
-  @transient private var totalBlocks = -1
-  @transient private var totalBytes = -1
-  @transient private var hasBlocks = 0
+  /** Fetch torrent blocks from the driver and/or other executors. */
+  private def readBlocks(): Array[Array[Byte]] = {
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    var numBlocksAvailable = 0
+    val blocks = new Array[Array[Byte]](numBlocks)
 
-  if (!isLocal) {
-    sendBroadcast()
+    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
+      SparkEnv.get.blockManager.getSingle(pieceId) match {
+        case Some(x) =>
+          blocks(pid) = x.asInstanceOf[Array[Byte]]
+          numBlocksAvailable += 1
+          SparkEnv.get.blockManager.putBytes(
+            pieceId,
+            ByteBuffer.wrap(blocks(pid)),
+            StorageLevel.MEMORY_AND_DISK_SER,
+            tellMaster = true)
+
+        case None =>
+          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+      }
+    }
+    blocks
   }
 
   /**
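An illustrative usage sketch of the mechanism described in the new scaladoc above, separate from the patch itself. It assumes an existing SparkContext named sc and that Spark is configured to use the torrent broadcast factory via spark.broadcast.factory; the broadcast id and the "piece" blocks are managed internally by TorrentBroadcast.

// Driver side: broadcast() serializes the value, blockifies it, and stores the pieces
// in the driver's BlockManager (the writeBlocks path above).
val lookup = Map("a" -> 1, "b" -> 2)                  // hypothetical data
val bcast = sc.broadcast(lookup)

// Executor side: when a task deserializes the broadcast variable, readObject/readBlocks
// above fetch the pieces from the driver and/or other executors and cache the merged value.
val counts = sc.parallelize(Seq("a", "b", "a"))
  .map(key => bcast.value.getOrElse(key, 0))
  .collect()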
@@ -79,26 +127,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](
     TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }
 
-  private def sendBroadcast() {
-    val tInfo = TorrentBroadcast.blockifyObject(value_)
-    totalBlocks = tInfo.totalBlocks
-    totalBytes = tInfo.totalBytes
-    hasBlocks = tInfo.totalBlocks
-
-    // Store meta-info
-    val metaId = BroadcastBlockId(id, "meta")
-    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
-    SparkEnv.get.blockManager.putSingle(
-      metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-    // Store individual pieces
-    for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastBlockId(id, "piece" + i)
-      SparkEnv.get.blockManager.putSingle(
-        pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-    }
-  }
-
   /** Used by the JVM when serializing this object. */
   private def writeObject(out: ObjectOutputStream) {
     assertValid()
@@ -109,99 +137,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {
-      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+      SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          value_ = x.asInstanceOf[T]
+          _value = x.asInstanceOf[T]
 
         case None =>
-          val start = System.nanoTime
           logInfo("Started reading broadcast variable " + id)
-
-          // Initialize @transient variables that will receive garbage values from the master.
-          resetWorkerVariables()
-
-          if (receiveBroadcast()) {
-            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
-            /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
-             * This creates a trade-off between memory usage and latency. Storing copy doubles
-             * the memory footprint; not storing doubles deserialization cost. Also,
-             * this does not need to be reported to BlockManagerMaster since other executors
-             * does not need to access this block (they only need to fetch the chunks,
-             * which are reported).
-             */
-            SparkEnv.get.blockManager.putSingle(
-              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
-            // Remove arrayOfBlocks from memory once value_ is on local cache
-            resetWorkerVariables()
-          } else {
-            logError("Reading broadcast variable " + id + " failed")
-          }
-
-          val time = (System.nanoTime - start) / 1e9
+          val start = System.nanoTime()
+          val blocks = readBlocks()
+          val time = (System.nanoTime() - start) / 1e9
           logInfo("Reading broadcast variable " + id + " took " + time + " s")
-      }
-    }
-  }
-
-  private def resetWorkerVariables() {
-    arrayOfBlocks = null
-    totalBytes = -1
-    totalBlocks = -1
-    hasBlocks = 0
-  }
-
-  private def receiveBroadcast(): Boolean = {
-    // Receive meta-info about the size of broadcast data,
-    // the number of chunks it is divided into, etc.
-    val metaId = BroadcastBlockId(id, "meta")
-    var attemptId = 10
-    while (attemptId > 0 && totalBlocks == -1) {
-      SparkEnv.get.blockManager.getSingle(metaId) match {
-        case Some(x) =>
-          val tInfo = x.asInstanceOf[TorrentInfo]
-          totalBlocks = tInfo.totalBlocks
-          totalBytes = tInfo.totalBytes
-          arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
-          hasBlocks = 0
-
-        case None =>
-          Thread.sleep(500)
-      }
-      attemptId -= 1
-    }
-
-    if (totalBlocks == -1) {
-      return false
-    }
 
-    /*
-     * Fetch actual chunks of data. Note that all these chunks are stored in
-     * the BlockManager and reported to the master, so that other executors
-     * can find out and pull the chunks from this executor.
-     */
-    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
-    for (pid <- recvOrder) {
-      val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
-        case Some(x) =>
-          arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
-          hasBlocks += 1
+          _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+          // Store the merged copy in BlockManager so other tasks on this executor don't
+          // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
-            pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-        case None =>
-          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
       }
     }
-
-    hasBlocks == totalBlocks
   }
-
 }
 
-private[broadcast] object TorrentBroadcast extends Logging {
+
+private object TorrentBroadcast extends Logging {
+  /** Size of each block. Default value is 4MB. */
  private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
  private var initialized = false
  private var conf: SparkConf = null
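A rough worked example of the block-count arithmetic, separate from the patch: BLOCK_SIZE above is spark.broadcast.blockSize (in KB) times 1024, 4 MB by default, and the ceiling division appears in blockifyObject in the next hunk. The object size here is made up.

val blockSize = 4096 * 1024                               // default BLOCK_SIZE, in bytes
val serializedLength = 10 * 1024 * 1024                   // hypothetical 10 MB serialized object
val numBlocks = math.ceil(serializedLength.toDouble / blockSize).toInt   // = 3
// The pieces would be stored as BroadcastBlockId(id, "piece0") through BroadcastBlockId(id, "piece2").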
@@ -223,52 +182,37 @@ private[broadcast] object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+    // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
+    // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
     val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
+    val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
+    val blocks = new Array[Array[Byte]](numBlocks)
 
-    var blockNum = byteArray.length / BLOCK_SIZE
-    if (byteArray.length % BLOCK_SIZE != 0) {
-      blockNum += 1
-    }
-
-    val blocks = new Array[TorrentBlock](blockNum)
     var blockId = 0
-
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)
 
-      blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+      blocks(blockId) = tempByteArray
       blockId += 1
     }
     bais.close()
-
-    val info = TorrentInfo(blocks, blockNum, byteArray.length)
-    info.hasBlocks = blockNum
-    info
+    blocks
   }
 
-  def unBlockifyObject[T: ClassTag](
-      arrayOfBlocks: Array[TorrentBlock],
-      totalBytes: Int,
-      totalBlocks: Int): T = {
-    val retByteArray = new Array[Byte](totalBytes)
-    for (i <- 0 until totalBlocks) {
-      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
-        i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
-    }
+  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+    val is = new SequenceInputStream(
+      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+    val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
 
-    val in: InputStream = {
-      val arrIn = new ByteArrayInputStream(retByteArray)
-      if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
-    }
     val ser = SparkEnv.get.serializer.newInstance()
     val serIn = ser.deserializeStream(in)
     val obj = serIn.readObject[T]()
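For reference, a minimal self-contained sketch of the chunk-and-rejoin idea that blockifyObject/unBlockifyObject implement above, separate from the patch; Spark's serializer and compression are omitted, and the payload and block size are made up.

import java.io.{ByteArrayInputStream, SequenceInputStream}
import scala.collection.JavaConversions.asJavaEnumeration

// Split a byte array into fixed-size chunks, then read the chunks back as one stream.
val blockSize = 4                                         // tiny block size, for illustration only
val payload = "hello torrent broadcast".getBytes("UTF-8")
val blocks: Array[Array[Byte]] = payload.grouped(blockSize).toArray

// SequenceInputStream concatenates the per-chunk streams, as unBlockifyObject does.
val in = new SequenceInputStream(
  asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
val roundTripped = Stream.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray
assert(roundTripped.sameElements(payload))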
@@ -284,17 +228,3 @@ private[broadcast] object TorrentBroadcast extends Logging {
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }
-
-private[broadcast] case class TorrentBlock(
-    blockID: Int,
-    byteArray: Array[Byte])
-  extends Serializable
-
-private[broadcast] case class TorrentInfo(
-    @transient arrayOfBlocks: Array[TorrentBlock],
-    totalBlocks: Int,
-    totalBytes: Int)
-  extends Serializable {
-
-  @transient var hasBlocks = 0
-}