Skip to content

Commit 104a89a

Browse files
committed
Fixed failing BroadcastSuite unit tests by introducing blocking for removeShuffle and removeBroadcast in BlockManager*
1 parent a430f06 commit 104a89a

14 files changed

+190
-103
lines changed

core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
112112
logDebug("Got cleaning task " + task)
113113
referenceBuffer -= reference.get
114114
task match {
115-
case CleanRDD(rddId) => doCleanupRDD(rddId)
116-
case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
117-
case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
115+
case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = false)
116+
case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId, blocking = false)
117+
case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = false)
118118
}
119119
}
120120
} catch {
@@ -124,23 +124,23 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
124124
}
125125

126126
/** Perform RDD cleanup. */
127-
private def doCleanupRDD(rddId: Int) {
127+
private def doCleanupRDD(rddId: Int, blocking: Boolean) {
128128
try {
129129
logDebug("Cleaning RDD " + rddId)
130-
sc.unpersistRDD(rddId, blocking = false)
130+
sc.unpersistRDD(rddId, blocking)
131131
listeners.foreach(_.rddCleaned(rddId))
132132
logInfo("Cleaned RDD " + rddId)
133133
} catch {
134134
case t: Throwable => logError("Error cleaning RDD " + rddId, t)
135135
}
136136
}
137137

138-
/** Perform shuffle cleanup. */
139-
private def doCleanupShuffle(shuffleId: Int) {
138+
/** Perform shuffle cleanup, asynchronously unless blocking is set to true. */
139+
private def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
140140
try {
141141
logDebug("Cleaning shuffle " + shuffleId)
142142
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
143-
blockManagerMaster.removeShuffle(shuffleId)
143+
blockManagerMaster.removeShuffle(shuffleId, blocking)
144144
listeners.foreach(_.shuffleCleaned(shuffleId))
145145
logInfo("Cleaned shuffle " + shuffleId)
146146
} catch {
@@ -149,10 +149,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
149149
}
150150

151151
/** Perform broadcast cleanup. */
152-
private def doCleanupBroadcast(broadcastId: Long) {
152+
private def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
153153
try {
154154
logDebug("Cleaning broadcast " + broadcastId)
155-
broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
155+
broadcastManager.unbroadcast(broadcastId, true, blocking)
156156
listeners.foreach(_.broadcastCleaned(broadcastId))
157157
logInfo("Cleaned broadcast " + broadcastId)
158158
} catch {
@@ -164,18 +164,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
164164
private def broadcastManager = sc.env.broadcastManager
165165
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
166166

167-
// Used for testing
167+
// Used for testing, explicitly blocks until cleanup is completed
168168

169169
def cleanupRDD(rdd: RDD[_]) {
170-
doCleanupRDD(rdd.id)
170+
doCleanupRDD(rdd.id, blocking = true)
171171
}
172172

173173
def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
174-
doCleanupShuffle(shuffleDependency.shuffleId)
174+
doCleanupShuffle(shuffleDependency.shuffleId, blocking = true)
175175
}
176176

177177
def cleanupBroadcast[T](broadcast: Broadcast[T]) {
178-
doCleanupBroadcast(broadcast.id)
178+
doCleanupBroadcast(broadcast.id, blocking = true)
179179
}
180180
}
181181

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,31 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
6161

6262
def value: T
6363

64+
/**
65+
* Asynchronously delete cached copies of this broadcast on the executors.
66+
* If the broadcast is used after this is called, it will need to be re-sent to each executor.
67+
*/
68+
def unpersist() {
69+
unpersist(blocking = false)
70+
}
71+
6472
/**
6573
* Delete cached copies of this broadcast on the executors. If the broadcast is used after
6674
* this is called, it will need to be re-sent to each executor.
75+
* @param blocking Whether to block until unpersisting has completed
6776
*/
68-
def unpersist()
77+
def unpersist(blocking: Boolean)
6978

7079
/**
7180
* Remove all persisted state associated with this broadcast on both the executors and
7281
* the driver.
7382
*/
74-
private[spark] def destroy() {
83+
private[spark] def destroy(blocking: Boolean) {
7584
_isValid = false
76-
onDestroy()
85+
onDestroy(blocking)
7786
}
7887

79-
protected def onDestroy()
88+
protected def onDestroy(blocking: Boolean)
8089

8190
/**
8291
* If this broadcast is no longer valid, throw an exception.

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@ import org.apache.spark.SparkConf
2929
trait BroadcastFactory {
3030
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager)
3131
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
32-
def unbroadcast(id: Long, removeFromDriver: Boolean)
32+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean)
3333
def stop()
3434
}

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ private[spark] class BroadcastManager(
6060
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
6161
}
6262

63-
def unbroadcast(id: Long, removeFromDriver: Boolean) {
64-
broadcastFactory.unbroadcast(id, removeFromDriver)
63+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
64+
broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
6565
}
6666
}

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
5050
/**
5151
* Remove all persisted state associated with this HTTP broadcast on the executors.
5252
*/
53-
def unpersist() {
54-
HttpBroadcast.unpersist(id, removeFromDriver = false)
53+
def unpersist(blocking: Boolean) {
54+
HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
5555
}
5656

57-
protected def onDestroy() {
58-
HttpBroadcast.unpersist(id, removeFromDriver = true)
57+
protected def onDestroy(blocking: Boolean) {
58+
HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
5959
}
6060

6161
// Used by the JVM when serializing this object
@@ -194,8 +194,8 @@ private[spark] object HttpBroadcast extends Logging {
194194
* If removeFromDriver is true, also remove these persisted blocks on the driver
195195
* and delete the associated broadcast file.
196196
*/
197-
def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
198-
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
197+
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
198+
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
199199
if (removeFromDriver) {
200200
val file = getFile(id)
201201
files.remove(file.toString)

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ class HttpBroadcastFactory extends BroadcastFactory {
3434

3535
/**
3636
* Remove all persisted state associated with the HTTP broadcast with the given ID.
37-
* @param removeFromDriver Whether to remove state from the driver.
37+
* @param removeFromDriver Whether to remove state from the driver
38+
* @param blocking Whether to block until unbroadcasted
3839
*/
39-
def unbroadcast(id: Long, removeFromDriver: Boolean) {
40-
HttpBroadcast.unpersist(id, removeFromDriver)
40+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
41+
HttpBroadcast.unpersist(id, removeFromDriver, blocking)
4142
}
4243
}

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
5353
/**
5454
* Remove all persisted state associated with this Torrent broadcast on the executors.
5555
*/
56-
def unpersist() {
57-
TorrentBroadcast.unpersist(id, removeFromDriver = false)
56+
def unpersist(blocking: Boolean) {
57+
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
5858
}
5959

60-
protected def onDestroy() {
61-
TorrentBroadcast.unpersist(id, removeFromDriver = true)
60+
protected def onDestroy(blocking: Boolean) {
61+
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
6262
}
6363

6464
private def sendBroadcast() {
@@ -242,8 +242,8 @@ private[spark] object TorrentBroadcast extends Logging {
242242
* Remove all persisted blocks associated with this torrent broadcast on the executors.
243243
* If removeFromDriver is true, also remove these persisted blocks on the driver.
244244
*/
245-
def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
246-
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
245+
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
246+
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
247247
}
248248
}
249249

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ class TorrentBroadcastFactory extends BroadcastFactory {
3636
/**
3737
* Remove all persisted state associated with the torrent broadcast with the given ID.
3838
* @param removeFromDriver Whether to remove state from the driver.
39+
* @param blocking Whether to block until unbroadcasted
3940
*/
40-
def unbroadcast(id: Long, removeFromDriver: Boolean) {
41-
TorrentBroadcast.unpersist(id, removeFromDriver)
41+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
42+
TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
4243
}
4344
}

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,12 +829,13 @@ private[spark] class BlockManager(
829829
/**
830830
* Remove all blocks belonging to the given broadcast.
831831
*/
832-
def removeBroadcast(broadcastId: Long, tellMaster: Boolean) {
832+
def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
833833
logInfo("Removing broadcast " + broadcastId)
834834
val blocksToRemove = blockInfo.keys.collect {
835835
case bid @ BroadcastBlockId(`broadcastId`, _) => bid
836836
}
837837
blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
838+
blocksToRemove.size
838839
}
839840

840841
/**

core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,28 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
117117
}
118118
}
119119

120-
/** Remove all blocks belonging to the given shuffle asynchronously. */
121-
def removeShuffle(shuffleId: Int) {
122-
askDriverWithReply(RemoveShuffle(shuffleId))
120+
/** Remove all blocks belonging to the given shuffle. */
121+
def removeShuffle(shuffleId: Int, blocking: Boolean) {
122+
val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
123+
future.onFailure {
124+
case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
125+
}
126+
if (blocking) {
127+
Await.result(future, timeout)
128+
}
123129
}
124130

125-
/** Remove all blocks belonging to the given broadcast asynchronously. */
126-
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) {
127-
askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster))
131+
/** Remove all blocks belonging to the given broadcast. */
132+
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
133+
val future = askDriverWithReply[Future[Seq[Int]]](RemoveBroadcast(broadcastId, removeFromMaster))
134+
future.onFailure {
135+
case e: Throwable =>
136+
logError("Failed to remove broadcast " + broadcastId +
137+
" with removeFromMaster = " + removeFromMaster, e)
138+
}
139+
if (blocking) {
140+
Await.result(future, timeout)
141+
}
128142
}
129143

130144
/**

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
100100
sender ! removeRdd(rddId)
101101

102102
case RemoveShuffle(shuffleId) =>
103-
removeShuffle(shuffleId)
104-
sender ! true
103+
sender ! removeShuffle(shuffleId)
105104

106105
case RemoveBroadcast(broadcastId, removeFromDriver) =>
107-
removeBroadcast(broadcastId, removeFromDriver)
108-
sender ! true
106+
sender ! removeBroadcast(broadcastId, removeFromDriver)
109107

110108
case RemoveBlock(blockId) =>
111109
removeBlockFromWorkers(blockId)
@@ -150,28 +148,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
150148
// The dispatcher is used as an implicit argument into the Future sequence construction.
151149
import context.dispatcher
152150
val removeMsg = RemoveRdd(rddId)
153-
Future.sequence(blockManagerInfo.values.map { bm =>
154-
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
155-
}.toSeq)
151+
Future.sequence(
152+
blockManagerInfo.values.map { bm =>
153+
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
154+
}.toSeq
155+
)
156156
}
157157

158-
private def removeShuffle(shuffleId: Int) {
158+
private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
159159
// Nothing to do in the BlockManagerMasterActor data structures
160+
import context.dispatcher
160161
val removeMsg = RemoveShuffle(shuffleId)
161-
blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg }
162+
Future.sequence(
163+
blockManagerInfo.values.map { bm =>
164+
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
165+
}.toSeq
166+
)
162167
}
163168

164169
/**
165170
* Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
166171
* of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
167172
* from the executors, but not from the driver.
168173
*/
169-
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
174+
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
170175
// TODO: Consolidate usages of <driver>
176+
import context.dispatcher
171177
val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
172-
blockManagerInfo.values
173-
.filter { info => removeFromDriver || info.blockManagerId.executorId != "<driver>" }
174-
.foreach { bm => bm.slaveActor ! removeMsg }
178+
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
179+
removeFromDriver || info.blockManagerId.executorId != "<driver>"
180+
}
181+
Future.sequence(
182+
requiredBlockManagers.map { bm =>
183+
bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
184+
}.toSeq
185+
)
175186
}
176187

177188
private def removeBlockManager(blockManagerId: BlockManagerId) {

core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.storage
1919

2020
import scala.concurrent.Future
2121

22-
import akka.actor.Actor
22+
import akka.actor.{ActorRef, Actor}
2323

2424
import org.apache.spark.{Logging, MapOutputTracker}
2525
import org.apache.spark.storage.BlockManagerMessages._
@@ -39,35 +39,44 @@ class BlockManagerSlaveActor(
3939
// Operations that involve removing blocks may be slow and should be done asynchronously
4040
override def receive = {
4141
case RemoveBlock(blockId) =>
42-
val removeBlock = Future { blockManager.removeBlock(blockId) }
43-
removeBlock.onFailure { case t: Throwable =>
44-
logError("Error in removing block " + blockId, t)
42+
doAsync("removing block", sender) {
43+
blockManager.removeBlock(blockId)
44+
true
4545
}
4646

4747
case RemoveRdd(rddId) =>
48-
val removeRdd = Future { sender ! blockManager.removeRdd(rddId) }
49-
removeRdd.onFailure { case t: Throwable =>
50-
logError("Error in removing RDD " + rddId, t)
48+
doAsync("removing RDD", sender) {
49+
blockManager.removeRdd(rddId)
5150
}
5251

5352
case RemoveShuffle(shuffleId) =>
54-
val removeShuffle = Future {
53+
doAsync("removing shuffle", sender) {
5554
blockManager.shuffleBlockManager.removeShuffle(shuffleId)
56-
if (mapOutputTracker != null) {
57-
mapOutputTracker.unregisterShuffle(shuffleId)
58-
}
59-
}
60-
removeShuffle.onFailure { case t: Throwable =>
61-
logError("Error in removing shuffle " + shuffleId, t)
6255
}
6356

6457
case RemoveBroadcast(broadcastId, tellMaster) =>
65-
val removeBroadcast = Future { blockManager.removeBroadcast(broadcastId, tellMaster) }
66-
removeBroadcast.onFailure { case t: Throwable =>
67-
logError("Error in removing broadcast " + broadcastId, t)
58+
doAsync("removing broadcast", sender) {
59+
blockManager.removeBroadcast(broadcastId, tellMaster)
6860
}
6961

7062
case GetBlockStatus(blockId, _) =>
7163
sender ! blockManager.getStatus(blockId)
7264
}
65+
66+
private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
67+
val future = Future {
68+
logDebug(actionMessage)
69+
val response = body
70+
response
71+
}
72+
future.onSuccess { case response =>
73+
logDebug("Successful in " + actionMessage + ", response is " + response)
74+
responseActor ! response
75+
logDebug("Sent response: " + response + " to " + responseActor)
76+
}
77+
future.onFailure { case t: Throwable =>
78+
logError("Error in " + actionMessage, t)
79+
responseActor ! null.asInstanceOf[T]
80+
}
81+
}
7382
}

0 commit comments

Comments
 (0)