Skip to content

Commit ba52e00

Browse files
committed
Refactor broadcast classes
1 parent c7ccef1 commit ba52e00

File tree

11 files changed

+169
-268
lines changed

11 files changed

+169
-268
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -641,13 +641,8 @@ class SparkContext(
641641
* Broadcast a read-only variable to the cluster, returning a
642642
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
643643
* The variable will be sent to each cluster only once.
644-
*
645-
* If `registerBlocks` is true, workers will notify driver about blocks they create
646-
* and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
647644
*/
648-
def broadcast[T](value: T, registerBlocks: Boolean = false) = {
649-
env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
650-
}
645+
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
651646

652647
/**
653648
* Add a file to be downloaded with this Spark job on every node.

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
package org.apache.spark.broadcast
1919

2020
import java.io.Serializable
21-
import java.util.concurrent.atomic.AtomicLong
22-
23-
import org.apache.spark._
2421

2522
/**
2623
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
@@ -53,56 +50,8 @@ import org.apache.spark._
5350
abstract class Broadcast[T](val id: Long) extends Serializable {
5451
def value: T
5552

56-
/**
57-
* Removes all blocks of this broadcast from memory (and disk if removeSource is true).
58-
*
59-
* @param removeSource Whether to remove data from disk as well.
60-
* Will cause errors if broadcast is accessed on workers afterwards
61-
* (e.g. in case of RDD re-computation due to executor failure).
62-
*/
63-
def unpersist(removeSource: Boolean = false)
64-
6553
// We cannot have an abstract readObject here due to some weird issues with
6654
// readObject having to be 'private' in sub-classes.
6755

6856
override def toString = "Broadcast(" + id + ")"
6957
}
70-
71-
private[spark]
72-
class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
73-
extends Logging with Serializable {
74-
75-
private var initialized = false
76-
private var broadcastFactory: BroadcastFactory = null
77-
78-
initialize()
79-
80-
// Called by SparkContext or Executor before using Broadcast
81-
private def initialize() {
82-
synchronized {
83-
if (!initialized) {
84-
val broadcastFactoryClass = conf.get(
85-
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
86-
87-
broadcastFactory =
88-
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
89-
90-
// Initialize appropriate BroadcastFactory and BroadcastObject
91-
broadcastFactory.initialize(isDriver, conf, securityManager)
92-
93-
initialized = true
94-
}
95-
}
96-
}
97-
98-
def stop() {
99-
broadcastFactory.stop()
100-
}
101-
102-
private val nextBroadcastId = new AtomicLong(0)
103-
104-
def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
105-
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)
106-
107-
def isDriver = _isDriver
108-
}

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ import org.apache.spark.SparkConf
2828
*/
2929
trait BroadcastFactory {
3030
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
31-
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
31+
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
3232
def stop(): Unit
3333
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.broadcast
19+
20+
import java.util.concurrent.atomic.AtomicLong
21+
22+
import org.apache.spark._
23+
24+
private[spark] class BroadcastManager(
25+
val isDriver: Boolean,
26+
conf: SparkConf,
27+
securityManager: SecurityManager)
28+
extends Logging with Serializable {
29+
30+
private var initialized = false
31+
private var broadcastFactory: BroadcastFactory = null
32+
33+
initialize()
34+
35+
// Called by SparkContext or Executor before using Broadcast
36+
private def initialize() {
37+
synchronized {
38+
if (!initialized) {
39+
val broadcastFactoryClass =
40+
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
41+
42+
broadcastFactory =
43+
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
44+
45+
// Initialize appropriate BroadcastFactory and BroadcastObject
46+
broadcastFactory.initialize(isDriver, conf, securityManager)
47+
48+
initialized = true
49+
}
50+
}
51+
}
52+
53+
def stop() {
54+
broadcastFactory.stop()
55+
}
56+
57+
private val nextBroadcastId = new AtomicLong(0)
58+
59+
def newBroadcast[T](value_ : T, isLocal: Boolean) = {
60+
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
61+
}
62+
63+
}

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,11 @@ import org.apache.spark.io.CompressionCodec
2929
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
3030
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
3131

32-
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
32+
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
3333
extends Broadcast[T](id) with Logging with Serializable {
3434

3535
def value = value_
3636

37-
def unpersist(removeSource: Boolean) {
38-
HttpBroadcast.synchronized {
39-
SparkEnv.get.blockManager.master.removeBlock(blockId)
40-
SparkEnv.get.blockManager.removeBlock(blockId)
41-
}
42-
43-
if (removeSource) {
44-
HttpBroadcast.synchronized {
45-
HttpBroadcast.cleanupById(id)
46-
}
47-
}
48-
}
49-
5037
def blockId = BroadcastBlockId(id)
5138

5239
HttpBroadcast.synchronized {
@@ -67,7 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
6754
logInfo("Started reading broadcast variable " + id)
6855
val start = System.nanoTime
6956
value_ = HttpBroadcast.read[T](id)
70-
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
57+
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
7158
val time = (System.nanoTime - start) / 1e9
7259
logInfo("Reading broadcast variable " + id + " took " + time + " s")
7360
}
@@ -76,20 +63,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
7663
}
7764
}
7865

79-
/**
80-
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
81-
*/
82-
class HttpBroadcastFactory extends BroadcastFactory {
83-
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
84-
HttpBroadcast.initialize(isDriver, conf, securityMgr)
85-
}
86-
87-
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
88-
new HttpBroadcast[T](value_, isLocal, id, registerBlocks)
89-
90-
def stop() { HttpBroadcast.stop() }
91-
}
92-
9366
private object HttpBroadcast extends Logging {
9467
private var initialized = false
9568

@@ -149,10 +122,8 @@ private object HttpBroadcast extends Logging {
149122
logInfo("Broadcast server started at " + serverUri)
150123
}
151124

152-
def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
153-
154125
def write(id: Long, value: Any) {
155-
val file = getFile(id)
126+
val file = new File(broadcastDir, BroadcastBlockId(id).name)
156127
val out: OutputStream = {
157128
if (compress) {
158129
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -198,30 +169,20 @@ private object HttpBroadcast extends Logging {
198169
obj
199170
}
200171

201-
def deleteFile(fileName: String) {
202-
try {
203-
new File(fileName).delete()
204-
logInfo("Deleted broadcast file '" + fileName + "'")
205-
} catch {
206-
case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
207-
}
208-
}
209-
210172
def cleanup(cleanupTime: Long) {
211173
val iterator = files.internalMap.entrySet().iterator()
212174
while(iterator.hasNext) {
213175
val entry = iterator.next()
214176
val (file, time) = (entry.getKey, entry.getValue)
215177
if (time < cleanupTime) {
216-
iterator.remove()
217-
deleteFile(file)
178+
try {
179+
iterator.remove()
180+
new File(file.toString).delete()
181+
logInfo("Deleted broadcast file '" + file + "'")
182+
} catch {
183+
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
184+
}
218185
}
219186
}
220187
}
221-
222-
def cleanupById(id: Long) {
223-
val file = getFile(id).getAbsolutePath
224-
files.internalMap.remove(file)
225-
deleteFile(file)
226-
}
227188
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.broadcast
19+
20+
import org.apache.spark.{SecurityManager, SparkConf}
21+
22+
/**
23+
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
24+
*/
25+
class HttpBroadcastFactory extends BroadcastFactory {
26+
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
27+
HttpBroadcast.initialize(isDriver, conf, securityMgr)
28+
}
29+
30+
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
31+
new HttpBroadcast[T](value_, isLocal, id)
32+
33+
def stop() { HttpBroadcast.stop() }
34+
}

0 commit comments

Comments (0)