
Commit c1185cd
[SPARK-3119] Re-implementation of TorrentBroadcast.
This is a re-implementation of TorrentBroadcast, with the following changes:

1. Removes most of the mutable, transient state from TorrentBroadcast (e.g. totalBytes, number of blocks fetched).
2. Removes TorrentInfo and TorrentBlock.
3. Replaces the BlockManager.getSingle call in readObject with a getLocal, resulting in one less RPC call to the BlockManagerMasterActor to find the location of the block.
4. Removes the metadata block, resulting in one less block to fetch.
5. Removes an extra memory copy for deserialization, by using Java's SequenceInputStream (see the sketch below).
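To illustrate change 5, here is a minimal, self-contained sketch (not from the commit) of how SequenceInputStream lets a reader stream straight across the fetched chunks without first concatenating them into one large byte array:

    import java.io.{ByteArrayInputStream, SequenceInputStream}
    import scala.collection.JavaConversions.asJavaEnumeration

    object SequenceStreamDemo extends App {
      // Two chunks standing in for fetched broadcast pieces.
      val blocks: Array[Array[Byte]] =
        Array("hello ".getBytes("UTF-8"), "world".getBytes("UTF-8"))

      // Chain the per-chunk streams; no combined byte array is ever allocated.
      val in = new SequenceInputStream(
        asJavaEnumeration(blocks.iterator.map(b => new ByteArrayInputStream(b))))

      val recovered = Stream.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray
      assert(new String(recovered, "UTF-8") == "hello world")
    }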
1 parent 8257733 · commit c1185cd

File tree

3 files changed: +168 −239 lines


core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala (11 additions, 0 deletions)

@@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi
  */
 @DeveloperApi
 trait BroadcastFactory {
+
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+
+  /**
+   * Creates a new broadcast variable.
+   *
+   * @param value value to broadcast
+   * @param isLocal whether we are in local mode (single JVM process)
+   * @param id unique id representing this broadcast variable
+   */
   def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
+
   def stop(): Unit
 }
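For context, a minimal sketch of a conforming factory. This assumes Broadcast[T] exposes protected getValue/doUnpersist/doDestroy hooks; only getValue is visible in this diff, so the other two are assumptions, and the factory itself is hypothetical (it just wraps the value in-process, with no distribution):

    import scala.reflect.ClassTag
    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.broadcast.{Broadcast, BroadcastFactory}

    class InProcessBroadcastFactory extends BroadcastFactory {
      override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit = {}

      override def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] =
        new Broadcast[T](id) {
          override protected def getValue(): T = value
          override protected def doUnpersist(blocking: Boolean): Unit = {} // nothing cached remotely
          override protected def doDestroy(blocking: Boolean): Unit = {}  // nothing to tear down
        }

      override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {}

      override def stop(): Unit = {}
    }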

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala (97 additions, 167 deletions)

@@ -18,7 +18,9 @@
 package org.apache.spark.broadcast

 import java.io._
+import java.nio.ByteBuffer

+import scala.collection.JavaConversions.asJavaEnumeration
 import scala.reflect.ClassTag
 import scala.util.Random

@@ -27,41 +29,87 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}

 /**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- * protocol to do a distributed transfer of the broadcasted data to the executors.
- * The mechanism is as follows. The driver divides the serializes the broadcasted data,
- * divides it into smaller chunks, and stores them in the BlockManager of the driver.
- * These chunks are reported to the BlockManagerMaster so that all the executors can
- * learn the location of those chunks. The first time the broadcast variable (sent as
- * part of task) is deserialized at a executor, all the chunks are fetched using
- * the BlockManager. When all the chunks are fetched (initially from the driver's
- * BlockManager), they are combined and deserialized to recreate the broadcasted data.
- * However, the chunks are also stored in the BlockManager and reported to the
- * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- * made to other executors who already have those chunks, resulting in a distributed
- * fetching. This prevents the driver from being the bottleneck in sending out multiple
- * copies of the broadcast data (one per executor) as done by the
- * [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its BlockManager. If
+ * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
+ * other executors if available. Once it gets the chunks, it puts the chunks in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
+ * @param obj object to broadcast
+ * @param isLocal whether Spark is running in local mode (single JVM process).
+ * @param id A unique identifier for the broadcast variable.
  */
 private[spark] class TorrentBroadcast[T: ClassTag](
-    @transient var value_ : T, isLocal: Boolean, id: Long)
+    obj: T,
+    @transient private val isLocal: Boolean,
+    id: Long)
   extends Broadcast[T](id) with Logging with Serializable {

-  override protected def getValue() = value_
+  override protected def getValue() = _value
+
+  /**
+   * Value of the broadcast object. On driver, this is set directly by the constructor.
+   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
+   * blocks from the driver and/or other executors.
+   */
+  @transient private var _value: T = obj
+
+  /** Total number of blocks this broadcast variable contains. */
+  private val numBlocks: Int = writeBlocks()

   private val broadcastId = BroadcastBlockId(id)

-  SparkEnv.get.blockManager.putSingle(
-    broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+  /**
+   * Divide the object into multiple blocks and put those blocks in the block manager.
+   *
+   * @return number of blocks this broadcast variable is divided into
+   */
+  private def writeBlocks(): Int = {
+    val blocks = TorrentBroadcast.blockifyObject(_value)
+    blocks.zipWithIndex.foreach { case (block, i) =>
+      // TODO: Use putBytes directly.
+      SparkEnv.get.blockManager.putSingle(
+        BroadcastBlockId(id, "piece" + i),
+        blocks(i),
+        StorageLevel.MEMORY_AND_DISK_SER,
+        tellMaster = true)
+    }
+    blocks.length
+  }

-  @transient private var arrayOfBlocks: Array[TorrentBlock] = null
-  @transient private var totalBlocks = -1
-  @transient private var totalBytes = -1
-  @transient private var hasBlocks = 0
+  /** Fetch torrent blocks from the driver and/or other executors. */
+  private def readBlocks(): Array[Array[Byte]] = {
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    var numBlocksAvailable = 0
+    val blocks = new Array[Array[Byte]](numBlocks)

-  if (!isLocal) {
-    sendBroadcast()
+    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
+      SparkEnv.get.blockManager.getSingle(pieceId) match {
+        case Some(x) =>
+          blocks(pid) = x.asInstanceOf[Array[Byte]]
+          numBlocksAvailable += 1
+          SparkEnv.get.blockManager.putBytes(
+            pieceId,
+            ByteBuffer.wrap(blocks(pid)),
+            StorageLevel.MEMORY_AND_DISK_SER,
+            tellMaster = true)
+
+        case None =>
+          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+      }
+    }
+    blocks
   }

   /**
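The Random.shuffle(Seq.range(0, numBlocks)) in readBlocks is what spreads the load: if every executor fetched piece 0 first, the driver (initially the only holder) would serve the same block to everyone at once. A tiny standalone sketch of the idea (not from the commit):

    import scala.util.Random

    object FetchOrderDemo extends App {
      val numBlocks = 8
      // Each executor draws its own permutation, so concurrent fetches start on
      // different pieces and replicas of every piece appear quickly.
      (1 to 3).foreach { executor =>
        println(s"executor $executor fetch order: " + Random.shuffle(Seq.range(0, numBlocks)))
      }
    }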
@@ -79,26 +127,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](
     TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }

-  private def sendBroadcast() {
-    val tInfo = TorrentBroadcast.blockifyObject(value_)
-    totalBlocks = tInfo.totalBlocks
-    totalBytes = tInfo.totalBytes
-    hasBlocks = tInfo.totalBlocks
-
-    // Store meta-info
-    val metaId = BroadcastBlockId(id, "meta")
-    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
-    SparkEnv.get.blockManager.putSingle(
-      metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-    // Store individual pieces
-    for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastBlockId(id, "piece" + i)
-      SparkEnv.get.blockManager.putSingle(
-        pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-    }
-  }
-
   /** Used by the JVM when serializing this object. */
   private def writeObject(out: ObjectOutputStream) {
     assertValid()
@@ -109,99 +137,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {
-      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+      SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          value_ = x.asInstanceOf[T]
+          _value = x.asInstanceOf[T]

         case None =>
-          val start = System.nanoTime
           logInfo("Started reading broadcast variable " + id)
-
-          // Initialize @transient variables that will receive garbage values from the master.
-          resetWorkerVariables()
-
-          if (receiveBroadcast()) {
-            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
-            /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
-             * This creates a trade-off between memory usage and latency. Storing copy doubles
-             * the memory footprint; not storing doubles deserialization cost. Also,
-             * this does not need to be reported to BlockManagerMaster since other executors
-             * does not need to access this block (they only need to fetch the chunks,
-             * which are reported).
-             */
-            SparkEnv.get.blockManager.putSingle(
-              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
-            // Remove arrayOfBlocks from memory once value_ is on local cache
-            resetWorkerVariables()
-          } else {
-            logError("Reading broadcast variable " + id + " failed")
-          }
-
-          val time = (System.nanoTime - start) / 1e9
+          val start = System.nanoTime()
+          val blocks = readBlocks()
+          val time = (System.nanoTime() - start) / 1e9
           logInfo("Reading broadcast variable " + id + " took " + time + " s")
-      }
-    }
-  }
-
-  private def resetWorkerVariables() {
-    arrayOfBlocks = null
-    totalBytes = -1
-    totalBlocks = -1
-    hasBlocks = 0
-  }
-
-  private def receiveBroadcast(): Boolean = {
-    // Receive meta-info about the size of broadcast data,
-    // the number of chunks it is divided into, etc.
-    val metaId = BroadcastBlockId(id, "meta")
-    var attemptId = 10
-    while (attemptId > 0 && totalBlocks == -1) {
-      SparkEnv.get.blockManager.getSingle(metaId) match {
-        case Some(x) =>
-          val tInfo = x.asInstanceOf[TorrentInfo]
-          totalBlocks = tInfo.totalBlocks
-          totalBytes = tInfo.totalBytes
-          arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
-          hasBlocks = 0
-
-        case None =>
-          Thread.sleep(500)
-      }
-      attemptId -= 1
-    }
-
-    if (totalBlocks == -1) {
-      return false
-    }

-    /*
-     * Fetch actual chunks of data. Note that all these chunks are stored in
-     * the BlockManager and reported to the master, so that other executors
-     * can find out and pull the chunks from this executor.
-     */
-    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
-    for (pid <- recvOrder) {
-      val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
-        case Some(x) =>
-          arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
-          hasBlocks += 1
+          _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+          // Store the merged copy in BlockManager so other tasks on this executor don't
+          // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
-            pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-        case None =>
-          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
       }
     }
-
-    hasBlocks == totalBlocks
   }
-
 }

-private[broadcast] object TorrentBroadcast extends Logging {
+
+private object TorrentBroadcast extends Logging {
+  /** Size of each block. Default value is 4MB. */
   private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
   private var initialized = false
   private var conf: SparkConf = null
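The shape of readObject above is the classic check-local-cache-then-fetch-and-cache pattern, serialized under a lock so that only one task per executor rebuilds the value. Sketched in isolation (hypothetical cache type, illustration only, not Spark's BlockManager API):

    import scala.collection.mutable

    object BroadcastCacheDemo {
      private val cache = mutable.Map.empty[Long, Any]

      // The first caller pays the fetch-and-rebuild cost; later callers in the
      // same JVM get the cached merged copy.
      def getOrFetch[T](id: Long)(fetch: => T): T = synchronized {
        cache.getOrElseUpdate(id, fetch).asInstanceOf[T]
      }
    }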
@@ -223,52 +182,37 @@ private[broadcast] object TorrentBroadcast extends Logging {
     initialized = false
   }

-  def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+    // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
+    // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
     val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
+    val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
+    val blocks = new Array[Array[Byte]](numBlocks)

-    var blockNum = byteArray.length / BLOCK_SIZE
-    if (byteArray.length % BLOCK_SIZE != 0) {
-      blockNum += 1
-    }
-
-    val blocks = new Array[TorrentBlock](blockNum)
     var blockId = 0
-
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)

-      blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+      blocks(blockId) = tempByteArray
       blockId += 1
     }
     bais.close()
-
-    val info = TorrentInfo(blocks, blockNum, byteArray.length)
-    info.hasBlocks = blockNum
-    info
+    blocks
   }

-  def unBlockifyObject[T: ClassTag](
-      arrayOfBlocks: Array[TorrentBlock],
-      totalBytes: Int,
-      totalBlocks: Int): T = {
-    val retByteArray = new Array[Byte](totalBytes)
-    for (i <- 0 until totalBlocks) {
-      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
-        i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
-    }
+  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+    val is = new SequenceInputStream(
+      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+    val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is

-    val in: InputStream = {
-      val arrIn = new ByteArrayInputStream(retByteArray)
-      if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
-    }
     val ser = SparkEnv.get.serializer.newInstance()
     val serIn = ser.deserializeStream(in)
     val obj = serIn.readObject[T]()
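A self-contained sketch of the split-and-rejoin round trip that the two helpers above implement (mirroring, not reproducing, the commit's code; serialization and compression omitted). It shows the ceil-based block count and Scala's three-argument range `0 until (length, step)`:

    import java.io.{ByteArrayInputStream, SequenceInputStream}
    import scala.collection.JavaConversions.asJavaEnumeration

    object BlockifyDemo extends App {
      val BLOCK_SIZE = 4 // tiny block size for the demo; the real default is 4MB

      def blockify(bytes: Array[Byte]): Array[Array[Byte]] = {
        val numBlocks = math.ceil(bytes.length.toDouble / BLOCK_SIZE).toInt
        val blocks = new Array[Array[Byte]](numBlocks)
        var blockId = 0
        for (i <- 0 until (bytes.length, BLOCK_SIZE)) { // i = 0, 4, 8, ...
          blocks(blockId) = bytes.slice(i, math.min(i + BLOCK_SIZE, bytes.length))
          blockId += 1
        }
        blocks
      }

      def unBlockify(blocks: Array[Array[Byte]]): Array[Byte] = {
        val in = new SequenceInputStream(
          asJavaEnumeration(blocks.iterator.map(b => new ByteArrayInputStream(b))))
        Stream.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray
      }

      val original = "0123456789".getBytes("UTF-8")
      assert(unBlockify(blockify(original)).sameElements(original))
    }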
@@ -284,17 +228,3 @@ private[broadcast] object TorrentBroadcast extends Logging {
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }
-
-private[broadcast] case class TorrentBlock(
-    blockID: Int,
-    byteArray: Array[Byte])
-  extends Serializable
-
-private[broadcast] case class TorrentInfo(
-    @transient arrayOfBlocks: Array[TorrentBlock],
-    totalBlocks: Int,
-    totalBytes: Int)
-  extends Serializable {
-
-  @transient var hasBlocks = 0
-}
