
Commit 34f436f

Generalize BroadcastBlockId to remove BroadcastHelperBlockId
Rather than having a special-purpose BroadcastHelperBlockId just for TorrentBroadcast, we now have a single BroadcastBlockId with a possibly empty field. This simplifies broadcast clean-up, because we now only have to look for one type of block. This commit also simplifies BlockId JSON de/serialization in general: a block ID is serialized as its name string and parsed back through the regexes in BlockId.apply.
1 parent 0d17060 commit 34f436f
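
As a minimal sketch of the resulting naming scheme (assuming the post-commit BroadcastBlockId and BlockId.apply shown in the diffs below), the plain and field-carrying variants share one case class and round-trip through the string name:

    import org.apache.spark.storage.{BlockId, BroadcastBlockId}

    // A plain broadcast block: field defaults to "" and is omitted from the name.
    val plain = BroadcastBlockId(5L)
    assert(plain.name == "broadcast_5")

    // A field-carrying block, built through the new companion apply.
    val piece = BroadcastBlockId(5L, "piece0")
    assert(piece.name == "broadcast_5_piece0")

    // BlockId.apply parses names back via the BROADCAST and BROADCAST_FIELD regexes.
    assert(BlockId("broadcast_5").name == plain.name)
    assert(BlockId("broadcast_5_piece0").name == piece.name)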

6 files changed, +54 -140 lines changed

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 5 additions & 5 deletions
@@ -23,7 +23,7 @@ import scala.math
 import scala.util.Random
 
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
 private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -54,7 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     hasBlocks = tInfo.totalBlocks
 
     // Store meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+    val metaId = BroadcastBlockId(id, "meta")
     val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
     TorrentBroadcast.synchronized {
       SparkEnv.get.blockManager.putSingle(
@@ -63,7 +63,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
 
     // Store individual pieces
     for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+      val pieceId = BroadcastBlockId(id, "piece" + i)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.putSingle(
           pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
@@ -131,7 +131,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
 
   def receiveBroadcast(): Boolean = {
     // Receive meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+    val metaId = BroadcastBlockId(id, "meta")
     var attemptId = 10
     while (attemptId > 0 && totalBlocks == -1) {
       TorrentBroadcast.synchronized {
@@ -156,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     // Receive actual blocks
     val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
     for (pid <- recvOrder) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.getSingle(pieceId) match {
           case Some(x) =>
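
For concreteness, a hypothetical torrent broadcast with id = 2 and two pieces now stores all of its blocks under the one naming scheme (the ids here are illustrative, not taken from the diff):

    import org.apache.spark.storage.BroadcastBlockId

    val id = 2L
    val valueId = BroadcastBlockId(id)           // name: "broadcast_2"
    val metaId  = BroadcastBlockId(id, "meta")   // name: "broadcast_2_meta"
    val pieceIds = (0 until 2).map(i => BroadcastBlockId(id, "piece" + i))
    // pieceIds.map(_.name) == Vector("broadcast_2_piece0", "broadcast_2_piece1")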

core/src/main/scala/org/apache/spark/storage/BlockId.scala

Lines changed: 17 additions & 12 deletions
@@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId {
   def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
   def isRDD = isInstanceOf[RDDBlockId]
   def isShuffle = isInstanceOf[ShuffleBlockId]
-  def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+  def isBroadcast = isInstanceOf[BroadcastBlockId]
 
   override def toString = name
   override def hashCode = name.hashCode
@@ -48,18 +48,15 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI
   def name = "rdd_" + rddId + "_" + splitIndex
 }
 
-private[spark]
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+  extends BlockId {
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
+// Leave field as an instance variable to avoid matching on it
 private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
-  def name = "broadcast_" + broadcastId
-}
-
-private[spark]
-case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
-  def name = broadcastId.name + "_" + hType
+  var field = ""
+  def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
 }
 
 private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
@@ -80,11 +77,19 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
   def name = "test_" + id
 }
 
+private[spark] object BroadcastBlockId {
+  def apply(broadcastId: Long, field: String) = {
+    val blockId = new BroadcastBlockId(broadcastId)
+    blockId.field = field
+    blockId
+  }
+}
+
 private[spark] object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)".r
-  val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+  val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
   val TEST = "test_(.*)".r
@@ -97,8 +102,8 @@ private[spark] object BlockId {
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId) =>
       BroadcastBlockId(broadcastId.toLong)
-    case BROADCAST_HELPER(broadcastId, hType) =>
-      BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+    case BROADCAST_FIELD(broadcastId, field) =>
+      BroadcastBlockId(broadcastId.toLong, field)
     case TASKRESULT(taskId) =>
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
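
The comment on BroadcastBlockId is the crux of the design: because field is a var rather than a constructor parameter, the synthesized one-argument extractor ignores it, so a single pattern covers the value block and every helper block. A sketch (the owningBroadcast helper is hypothetical, for illustration only):

    import org.apache.spark.storage.{BlockId, BroadcastBlockId}

    def owningBroadcast(blockId: BlockId): Option[Long] = blockId match {
      case BroadcastBlockId(bid) => Some(bid)  // matches "broadcast_2" and "broadcast_2_meta" alike
      case _ => None
    }

    assert(owningBroadcast(BroadcastBlockId(2L)) == Some(2L))
    assert(owningBroadcast(BroadcastBlockId(2L, "meta")) == Some(2L))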

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 1 addition & 2 deletions
@@ -827,9 +827,8 @@ private[spark] class BlockManager(
    */
   def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
     logInfo("Removing broadcast " + broadcastId)
-    val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect {
+    val blocksToRemove = blockInfo.keys.collect {
       case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
-      case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) }
   }
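
This is where the simplification pays off: one collect case now gathers the value, meta, and piece blocks of a broadcast, with no prior isBroadcast filter. A sketch over a hypothetical key set:

    import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId}

    val keys: Seq[BlockId] = Seq(
      BroadcastBlockId(1L), BroadcastBlockId(1L, "meta"),
      BroadcastBlockId(1L, "piece0"), BroadcastBlockId(2L), RDDBlockId(0, 0))
    // The guard keeps only blocks belonging to broadcast 1, whatever their field.
    val toRemove = keys.collect { case bid: BroadcastBlockId if bid.broadcastId == 1L => bid }
    // toRemove.map(_.name) == Seq("broadcast_1", "broadcast_1_meta", "broadcast_1_piece0")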

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 2 additions & 75 deletions
@@ -195,7 +195,7 @@ private[spark] object JsonProtocol {
       taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
     val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
       JArray(blocks.toList.map { case (id, status) =>
-        ("Block ID" -> blockIdToJson(id)) ~
+        ("Block ID" -> id.toString) ~
         ("Status" -> blockStatusToJson(status))
       })
     }.getOrElse(JNothing)
@@ -284,35 +284,6 @@ private[spark] object JsonProtocol {
     ("Replication" -> storageLevel.replication)
   }
 
-  def blockIdToJson(blockId: BlockId): JValue = {
-    val blockType = Utils.getFormattedClassName(blockId)
-    val json: JObject = blockId match {
-      case rddBlockId: RDDBlockId =>
-        ("RDD ID" -> rddBlockId.rddId) ~
-        ("Split Index" -> rddBlockId.splitIndex)
-      case shuffleBlockId: ShuffleBlockId =>
-        ("Shuffle ID" -> shuffleBlockId.shuffleId) ~
-        ("Map ID" -> shuffleBlockId.mapId) ~
-        ("Reduce ID" -> shuffleBlockId.reduceId)
-      case broadcastBlockId: BroadcastBlockId =>
-        "Broadcast ID" -> broadcastBlockId.broadcastId
-      case broadcastHelperBlockId: BroadcastHelperBlockId =>
-        ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~
-        ("Helper Type" -> broadcastHelperBlockId.hType)
-      case taskResultBlockId: TaskResultBlockId =>
-        "Task ID" -> taskResultBlockId.taskId
-      case streamBlockId: StreamBlockId =>
-        ("Stream ID" -> streamBlockId.streamId) ~
-        ("Unique ID" -> streamBlockId.uniqueId)
-      case tempBlockId: TempBlockId =>
-        val uuid = UUIDToJson(tempBlockId.id)
-        "Temp ID" -> uuid
-      case testBlockId: TestBlockId =>
-        "Test ID" -> testBlockId.id
-    }
-    ("Type" -> blockType) ~ json
-  }
-
   def blockStatusToJson(blockStatus: BlockStatus): JValue = {
     val storageLevel = storageLevelToJson(blockStatus.storageLevel)
     ("Storage Level" -> storageLevel) ~
@@ -513,7 +484,7 @@ private[spark] object JsonProtocol {
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
     metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
       value.extract[List[JValue]].map { block =>
-        val id = blockIdFromJson(block \ "Block ID")
+        val id = BlockId((block \ "Block ID").extract[String])
         val status = blockStatusFromJson(block \ "Status")
         (id, status)
       }
@@ -616,50 +587,6 @@ private[spark] object JsonProtocol {
     StorageLevel(useDisk, useMemory, deserialized, replication)
   }
 
-  def blockIdFromJson(json: JValue): BlockId = {
-    val rddBlockId = Utils.getFormattedClassName(RDDBlockId)
-    val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId)
-    val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId)
-    val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId)
-    val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId)
-    val streamBlockId = Utils.getFormattedClassName(StreamBlockId)
-    val tempBlockId = Utils.getFormattedClassName(TempBlockId)
-    val testBlockId = Utils.getFormattedClassName(TestBlockId)
-
-    (json \ "Type").extract[String] match {
-      case `rddBlockId` =>
-        val rddId = (json \ "RDD ID").extract[Int]
-        val splitIndex = (json \ "Split Index").extract[Int]
-        new RDDBlockId(rddId, splitIndex)
-      case `shuffleBlockId` =>
-        val shuffleId = (json \ "Shuffle ID").extract[Int]
-        val mapId = (json \ "Map ID").extract[Int]
-        val reduceId = (json \ "Reduce ID").extract[Int]
-        new ShuffleBlockId(shuffleId, mapId, reduceId)
-      case `broadcastBlockId` =>
-        val broadcastId = (json \ "Broadcast ID").extract[Long]
-        new BroadcastBlockId(broadcastId)
-      case `broadcastHelperBlockId` =>
-        val broadcastBlockId =
-          blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId]
-        val hType = (json \ "Helper Type").extract[String]
-        new BroadcastHelperBlockId(broadcastBlockId, hType)
-      case `taskResultBlockId` =>
-        val taskId = (json \ "Task ID").extract[Long]
-        new TaskResultBlockId(taskId)
-      case `streamBlockId` =>
-        val streamId = (json \ "Stream ID").extract[Int]
-        val uniqueId = (json \ "Unique ID").extract[Long]
-        new StreamBlockId(streamId, uniqueId)
-      case `tempBlockId` =>
-        val tempId = UUIDFromJson(json \ "Temp ID")
-        new TempBlockId(tempId)
-      case `testBlockId` =>
-        val testId = (json \ "Test ID").extract[String]
-        new TestBlockId(testId)
-    }
-  }
-
   def blockStatusFromJson(json: JValue): BlockStatus = {
     val storageLevel = storageLevelFromJson(json \ "Storage Level")
     val memorySize = (json \ "Memory Size").extract[Long]
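
With the type-dispatched converters gone, a block ID crosses JSON as nothing but its name string. A minimal round-trip sketch, assuming json4s on the classpath as in JsonProtocol itself:

    import org.json4s._
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._
    import org.apache.spark.storage.{BlockId, BroadcastBlockId}

    implicit val formats = DefaultFormats

    // Serialization emits just the name...
    val json: JValue = ("Block ID" -> BroadcastBlockId(1L, "meta").toString)
    assert(compact(render(json)) == """{"Block ID":"broadcast_1_meta"}""")

    // ...and deserialization hands the string back to BlockId.apply.
    val id = BlockId((json \ "Block ID").extract[String])
    assert(id.name == "broadcast_1_meta")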

core/src/test/scala/org/apache/spark/BroadcastSuite.scala

Lines changed: 29 additions & 32 deletions
@@ -21,7 +21,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.storage._
 import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId}
+import org.apache.spark.storage.BroadcastBlockId
 
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
@@ -102,23 +102,22 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
    * are present only on the expected nodes.
    */
   private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
-    def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id))
+    def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
 
     // Verify that the broadcast file is created, and blocks are persisted only on the driver
-    def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       assert(blockIds.size === 1)
-      val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId]
-      val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0)
+      val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
       assert(levels.size === 1)
       levels.head match { case (bm, level) =>
         assert(bm.executorId === "<driver>")
         assert(level === StorageLevel.MEMORY_AND_DISK)
       }
-      assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists)
+      assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists)
     }
 
     // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       assert(blockIds.size === 1)
       val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
       assert(levels.size === numSlaves + 1)
@@ -129,12 +128,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
-    def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       assert(blockIds.size === 1)
-      val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId]
-      val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0)
+      val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
       assert(levels.size === (if (removeFromDriver) 0 else 1))
-      assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists)
+      assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists)
     }
 
     testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation,
@@ -151,14 +149,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
     def getBlockIds(id: Long) = {
       val broadcastBlockId = BroadcastBlockId(id)
-      val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta")
+      val metaBlockId = BroadcastBlockId(id, "meta")
       // Assume broadcast value is small enough to fit into 1 piece
-      val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0")
-      Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+      val pieceBlockId = BroadcastBlockId(id, "piece0")
+      Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
     }
 
     // Verify that blocks are persisted only on the driver
-    def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       blockIds.foreach { blockId =>
         val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0)
         assert(levels.size === 1)
@@ -170,27 +168,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     }
 
     // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       blockIds.foreach { blockId =>
         val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0)
-        blockId match {
-          case BroadcastHelperBlockId(_, "meta") =>
-            // Meta data is only on the driver
-            assert(levels.size === 1)
-            levels.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
-          case _ =>
-            // Other blocks are on both the executors and the driver
-            assert(levels.size === numSlaves + 1)
-            levels.foreach { case (_, level) =>
-              assert(level === StorageLevel.MEMORY_AND_DISK)
-            }
+        if (blockId.field == "meta") {
+          // Meta data is only on the driver
+          assert(levels.size === 1)
+          levels.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+        } else {
+          // Other blocks are on both the executors and the driver
+          assert(levels.size === numSlaves + 1)
+          levels.foreach { case (_, level) =>
+            assert(level === StorageLevel.MEMORY_AND_DISK)
+          }
         }
       }
     }
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true.
-    def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       val expectedNumBlocks = if (removeFromDriver) 0 else 1
       var waitTimeMs = 1000L
       blockIds.foreach { blockId =>
@@ -217,10 +214,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistBroadcast(
       numSlaves: Int,
       broadcastConf: SparkConf,
-      getBlockIds: Long => Seq[BlockId],
-      afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit,
-      afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit,
-      afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit,
+      getBlockIds: Long => Seq[BroadcastBlockId],
+      afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
       removeFromDriver: Boolean) {
 
     sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 0 additions & 14 deletions
@@ -104,15 +104,6 @@ class JsonProtocolSuite extends FunSuite {
     testTaskEndReason(TaskKilled)
     testTaskEndReason(ExecutorLostFailure)
     testTaskEndReason(UnknownReason)
-
-    // BlockId
-    testBlockId(RDDBlockId(1, 2))
-    testBlockId(ShuffleBlockId(1, 2, 3))
-    testBlockId(BroadcastBlockId(1L))
-    testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
-    testBlockId(TaskResultBlockId(1L))
-    testBlockId(StreamBlockId(1, 2L))
-    testBlockId(TempBlockId(UUID.randomUUID()))
   }
 
@@ -167,11 +158,6 @@ class JsonProtocolSuite extends FunSuite {
     assertEquals(reason, newReason)
   }
 
-  private def testBlockId(blockId: BlockId) {
-    val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId))
-    blockId == newBlockId
-  }
-
 
   /** -------------------------------- *
    | Util methods for comparing events |
