@@ -33,19 +33,55 @@ extends Broadcast[T](id) with Logging with Serializable {
33
33
def value = value_
34
34
35
35
def unpersist (removeSource : Boolean ) {
36
- SparkEnv .get.blockManager.master.removeBlock(broadcastId)
37
- SparkEnv .get.blockManager.removeBlock(broadcastId)
36
+ TorrentBroadcast .synchronized {
37
+ SparkEnv .get.blockManager.master.removeBlock(broadcastId)
38
+ SparkEnv .get.blockManager.removeBlock(broadcastId)
39
+ }
40
+
41
+ if (! removeSource) {
42
+ // We can't tell BlockManager master to remove blocks from all nodes except driver,
43
+ // so we need to save them here in order to store them on disk later.
44
+ // This may be inefficient if blocks were already dropped to disk,
45
+ // but since unpersist is supposed to be called right after working with
46
+ // a broadcast this should not happen (and getting them from memory is cheap).
47
+ arrayOfBlocks = new Array [TorrentBlock ](totalBlocks)
48
+
49
+ for (pid <- 0 until totalBlocks) {
50
+ val pieceId = pieceBlockId(pid)
51
+ TorrentBroadcast .synchronized {
52
+ SparkEnv .get.blockManager.getSingle(pieceId) match {
53
+ case Some (x) =>
54
+ arrayOfBlocks(pid) = x.asInstanceOf [TorrentBlock ]
55
+ case None =>
56
+ throw new SparkException (" Failed to get " + pieceId + " of " + broadcastId)
57
+ }
58
+ }
59
+ }
60
+ }
61
+
62
+ for (pid <- 0 until totalBlocks) {
63
+ TorrentBroadcast .synchronized {
64
+ SparkEnv .get.blockManager.master.removeBlock(pieceBlockId(pid))
65
+ }
66
+ }
38
67
39
68
if (removeSource) {
40
- for (pid <- pieceIds) {
41
- SparkEnv .get.blockManager.removeBlock(pieceBlockId(pid) )
69
+ TorrentBroadcast . synchronized {
70
+ SparkEnv .get.blockManager.removeBlock(metaId )
42
71
}
43
- SparkEnv .get.blockManager.removeBlock(metaId)
44
72
} else {
45
- for (pid <- pieceIds) {
46
- SparkEnv .get.blockManager.dropFromMemory(pieceBlockId(pid) )
73
+ TorrentBroadcast . synchronized {
74
+ SparkEnv .get.blockManager.dropFromMemory(metaId )
47
75
}
48
- SparkEnv .get.blockManager.dropFromMemory(metaId)
76
+
77
+ for (i <- 0 until totalBlocks) {
78
+ val pieceId = pieceBlockId(i)
79
+ TorrentBroadcast .synchronized {
80
+ SparkEnv .get.blockManager.putSingle(
81
+ pieceId, arrayOfBlocks(i), StorageLevel .DISK_ONLY , true )
82
+ }
83
+ }
84
+ arrayOfBlocks = null
49
85
}
50
86
}
51
87
@@ -128,11 +164,6 @@ extends Broadcast[T](id) with Logging with Serializable {
128
164
}
129
165
130
166
private def resetWorkerVariables () {
131
- if (arrayOfBlocks != null ) {
132
- for (pid <- pieceIds) {
133
- SparkEnv .get.blockManager.removeBlock(pieceBlockId(pid))
134
- }
135
- }
136
167
arrayOfBlocks = null
137
168
totalBytes = - 1
138
169
totalBlocks = - 1
0 commit comments