Skip to content

Commit 0d8ed5b

Browse files
committed
Added getBytes to BlockManager and uses that in TorrentBroadcast.
1 parent 2d6a5fb commit 0d8ed5b

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,19 @@ private[spark] class TorrentBroadcast[T: ClassTag](
9090
private def readBlocks(): Array[ByteBuffer] = {
9191
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
9292
// to the driver, so other executors can pull these chunks from this executor as well.
93-
var numBlocksAvailable = 0
9493
val blocks = new Array[ByteBuffer](numBlocks)
9594

9695
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
9796
val pieceId = BroadcastBlockId(id, "piece" + pid)
98-
SparkEnv.get.blockManager.getRemoteBytes(pieceId) match {
99-
case Some(x) =>
100-
blocks(pid) = x.asInstanceOf[ByteBuffer]
101-
numBlocksAvailable += 1
97+
// Note that we use getBytes rather than getRemoteBytes here because there is a chance
98+
// that previous attempts to fetch the broadcast blocks have already fetched some of the
99+
// blocks. In that case, some blocks would be available locally (on this executor).
100+
SparkEnv.get.blockManager.getBytes(pieceId) match {
101+
case Some(block) =>
102+
blocks(pid) = block
102103
SparkEnv.get.blockManager.putBytes(
103104
pieceId,
104-
blocks(pid),
105+
block,
105106
StorageLevel.MEMORY_AND_DISK_SER,
106107
tellMaster = true)
107108

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,16 @@ private[spark] class BlockManager(
517517
None
518518
}
519519

520+
def getBytes(blockId: BlockId): Option[ByteBuffer] = {
521+
val local = getLocalBytes(blockId)
522+
if (local.isDefined) {
523+
local
524+
} else {
525+
val remote = getRemoteBytes(blockId)
526+
remote
527+
}
528+
}
529+
520530
/**
521531
* Get a block from the block manager (either local or remote).
522532
*/

0 commit comments

Comments
 (0)