@@ -98,7 +98,13 @@ class DAGScheduler(
98
98
99
99
private [scheduler] val activeJobs = new HashSet [ActiveJob ]
100
100
101
- // Contains the locations that each RDD's partitions are cached on
101
+ /**
102
+ * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
103
+ * and its values are arrays indexed by partition numbers. Each array value is the set of
104
+ * locations where that RDD partition is cached.
105
+ *
106
+ * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
107
+ */
102
108
private val cacheLocs = new HashMap [Int , Array [Seq [TaskLocation ]]]
103
109
104
110
// For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
@@ -183,18 +189,17 @@ class DAGScheduler(
183
189
eventProcessLoop.post(TaskSetFailed (taskSet, reason))
184
190
}
185
191
186
- private def getCacheLocs (rdd : RDD [_]): Array [Seq [TaskLocation ]] = {
187
- if ( ! cacheLocs.contains (rdd.id)) {
192
+ private def getCacheLocs (rdd : RDD [_]): Array [Seq [TaskLocation ]] = cacheLocs. synchronized {
193
+ cacheLocs.getOrElseUpdate (rdd.id, {
188
194
val blockIds = rdd.partitions.indices.map(index => RDDBlockId (rdd.id, index)).toArray[BlockId ]
189
195
val locs = BlockManager .blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
190
- cacheLocs(rdd.id) = blockIds.map { id =>
196
+ blockIds.map { id =>
191
197
locs.getOrElse(id, Nil ).map(bm => TaskLocation (bm.host, bm.executorId))
192
198
}
193
- }
194
- cacheLocs(rdd.id)
199
+ })
195
200
}
196
201
197
- private def clearCacheLocs () {
202
+ private def clearCacheLocs (): Unit = cacheLocs. synchronized {
198
203
cacheLocs.clear()
199
204
}
200
205
@@ -1276,17 +1281,26 @@ class DAGScheduler(
1276
1281
}
1277
1282
1278
1283
/**
1279
- * Synchronized method that might be called from other threads.
1284
+ * Gets the locality information associated with a partition of a particular RDD.
1285
+ *
1286
+ * This method is thread-safe and is called from both DAGScheduler and SparkContext.
1287
+ *
1280
1288
* @param rdd whose partitions are to be looked at
1281
1289
* @param partition to lookup locality information for
1282
1290
* @return list of machines that are preferred by the partition
1283
1291
*/
1284
1292
private [spark]
1285
- def getPreferredLocs (rdd : RDD [_], partition : Int ): Seq [TaskLocation ] = synchronized {
1293
+ def getPreferredLocs (rdd : RDD [_], partition : Int ): Seq [TaskLocation ] = {
1286
1294
getPreferredLocsInternal(rdd, partition, new HashSet )
1287
1295
}
1288
1296
1289
- /** Recursive implementation for getPreferredLocs. */
1297
+ /**
1298
+ * Recursive implementation for getPreferredLocs.
1299
+ *
1300
+ * This method is thread-safe because it only accesses DAGScheduler state through thread-safe
1301
+ * methods (getCacheLocs()); please be careful when modifying this method, because any new
1302
+ * DAGScheduler state accessed by it may require additional synchronization.
1303
+ */
1290
1304
private def getPreferredLocsInternal (
1291
1305
rdd : RDD [_],
1292
1306
partition : Int ,
0 commit comments