@@ -21,7 +21,7 @@ import org.scalatest.FunSuite
21
21
22
22
import org .apache .spark .storage ._
23
23
import org .apache .spark .broadcast .HttpBroadcast
24
- import org .apache .spark .storage .{ BroadcastBlockId , BroadcastHelperBlockId }
24
+ import org .apache .spark .storage .BroadcastBlockId
25
25
26
26
class BroadcastSuite extends FunSuite with LocalSparkContext {
27
27
@@ -102,23 +102,22 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
102
102
* are present only on the expected nodes.
103
103
*/
104
104
private def testUnpersistHttpBroadcast (numSlaves : Int , removeFromDriver : Boolean ) {
105
- def getBlockIds (id : Long ) = Seq [BlockId ](BroadcastBlockId (id))
105
+ def getBlockIds (id : Long ) = Seq [BroadcastBlockId ](BroadcastBlockId (id))
106
106
107
107
// Verify that the broadcast file is created, and blocks are persisted only on the driver
108
- def afterCreation (blockIds : Seq [BlockId ], bmm : BlockManagerMaster ) {
108
+ def afterCreation (blockIds : Seq [BroadcastBlockId ], bmm : BlockManagerMaster ) {
109
109
assert(blockIds.size === 1 )
110
- val broadcastBlockId = blockIds.head.asInstanceOf [BroadcastBlockId ]
111
- val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0 )
110
+ val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0 )
112
111
assert(levels.size === 1 )
113
112
levels.head match { case (bm, level) =>
114
113
assert(bm.executorId === " <driver>" )
115
114
assert(level === StorageLevel .MEMORY_AND_DISK )
116
115
}
117
- assert(HttpBroadcast .getFile(broadcastBlockId .broadcastId).exists)
116
+ assert(HttpBroadcast .getFile(blockIds.head .broadcastId).exists)
118
117
}
119
118
120
119
// Verify that blocks are persisted in both the executors and the driver
121
- def afterUsingBroadcast (blockIds : Seq [BlockId ], bmm : BlockManagerMaster ) {
120
+ def afterUsingBroadcast (blockIds : Seq [BroadcastBlockId ], bmm : BlockManagerMaster ) {
122
121
assert(blockIds.size === 1 )
123
122
val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0 )
124
123
assert(levels.size === numSlaves + 1 )
@@ -129,12 +128,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
129
128
130
129
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
131
130
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
132
- def afterUnpersist (blockIds : Seq [BlockId ], bmm : BlockManagerMaster ) {
131
+ def afterUnpersist (blockIds : Seq [BroadcastBlockId ], bmm : BlockManagerMaster ) {
133
132
assert(blockIds.size === 1 )
134
- val broadcastBlockId = blockIds.head.asInstanceOf [BroadcastBlockId ]
135
- val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0 )
133
+ val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0 )
136
134
assert(levels.size === (if (removeFromDriver) 0 else 1 ))
137
- assert(removeFromDriver === ! HttpBroadcast .getFile(broadcastBlockId .broadcastId).exists)
135
+ assert(removeFromDriver === ! HttpBroadcast .getFile(blockIds.head .broadcastId).exists)
138
136
}
139
137
140
138
testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation,
@@ -151,14 +149,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
151
149
private def testUnpersistTorrentBroadcast (numSlaves : Int , removeFromDriver : Boolean ) {
152
150
def getBlockIds (id : Long ) = {
153
151
val broadcastBlockId = BroadcastBlockId (id)
154
- val metaBlockId = BroadcastHelperBlockId (broadcastBlockId , " meta" )
152
+ val metaBlockId = BroadcastBlockId (id , " meta" )
155
153
// Assume broadcast value is small enough to fit into 1 piece
156
- val pieceBlockId = BroadcastHelperBlockId (broadcastBlockId , " piece0" )
157
- Seq [BlockId ](broadcastBlockId, metaBlockId, pieceBlockId)
154
+ val pieceBlockId = BroadcastBlockId (id , " piece0" )
155
+ Seq [BroadcastBlockId ](broadcastBlockId, metaBlockId, pieceBlockId)
158
156
}
159
157
160
158
// Verify that blocks are persisted only on the driver
161
- def afterCreation (blockIds : Seq [BlockId ], bmm : BlockManagerMaster ) {
159
+ def afterCreation (blockIds : Seq [BroadcastBlockId ], bmm : BlockManagerMaster ) {
162
160
blockIds.foreach { blockId =>
163
161
val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0 )
164
162
assert(levels.size === 1 )
@@ -170,27 +168,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
170
168
}
171
169
172
170
// Verify that blocks are persisted in both the executors and the driver
173
- def afterUsingBroadcast (blockIds : Seq [BlockId ], bmm : BlockManagerMaster ) {
171
+ def afterUsingBroadcast (blockIds : Seq [BroadcastBlockId ], bmm : BlockManagerMaster ) {
174
172
blockIds.foreach { blockId =>
175
173
val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0 )
176
- blockId match {
177
- case BroadcastHelperBlockId (_, " meta" ) =>
178
- // Meta data is only on the driver
179
- assert(levels.size === 1 )
180
- levels.head match { case (bm, _) => assert(bm.executorId === " <driver>" ) }
181
- case _ =>
182
- // Other blocks are on both the executors and the driver
183
- assert(levels.size === numSlaves + 1 )
184
- levels.foreach { case (_, level) =>
185
- assert(level === StorageLevel .MEMORY_AND_DISK )
186
- }
174
+ if (blockId.field == " meta" ) {
175
+ // Meta data is only on the driver
176
+ assert(levels.size === 1 )
177
+ levels.head match { case (bm, _) => assert(bm.executorId === " <driver>" ) }
178
+ } else {
179
+ // Other blocks are on both the executors and the driver
180
+ assert(levels.size === numSlaves + 1 )
181
+ levels.foreach { case (_, level) =>
182
+ assert(level === StorageLevel .MEMORY_AND_DISK )
183
+ }
187
184
}
188
185
}
189
186
}
190
187
191
188
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
192
189
// is true.
193
- def afterUnpersist (blockIds : Seq [BlockId ], bmm : BlockManagerMaster ) {
190
+ def afterUnpersist (blockIds : Seq [BroadcastBlockId ], bmm : BlockManagerMaster ) {
194
191
val expectedNumBlocks = if (removeFromDriver) 0 else 1
195
192
var waitTimeMs = 1000L
196
193
blockIds.foreach { blockId =>
@@ -217,10 +214,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
217
214
private def testUnpersistBroadcast (
218
215
numSlaves : Int ,
219
216
broadcastConf : SparkConf ,
220
- getBlockIds : Long => Seq [BlockId ],
221
- afterCreation : (Seq [BlockId ], BlockManagerMaster ) => Unit ,
222
- afterUsingBroadcast : (Seq [BlockId ], BlockManagerMaster ) => Unit ,
223
- afterUnpersist : (Seq [BlockId ], BlockManagerMaster ) => Unit ,
217
+ getBlockIds : Long => Seq [BroadcastBlockId ],
218
+ afterCreation : (Seq [BroadcastBlockId ], BlockManagerMaster ) => Unit ,
219
+ afterUsingBroadcast : (Seq [BroadcastBlockId ], BlockManagerMaster ) => Unit ,
220
+ afterUnpersist : (Seq [BroadcastBlockId ], BlockManagerMaster ) => Unit ,
224
221
removeFromDriver : Boolean ) {
225
222
226
223
sc = new SparkContext (" local-cluster[%d, 1, 512]" .format(numSlaves), " test" , broadcastConf)
0 commit comments