Skip to content

Commit 12d64ba

Browse files
committed
Properly synchronize accesses to DAGScheduler cacheLocs map.
1 parent ee6e3ef commit 12d64ba

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ class DAGScheduler(
9898

9999
private[scheduler] val activeJobs = new HashSet[ActiveJob]
100100

101-
// Contains the locations that each RDD's partitions are cached on
101+
/**
102+
* Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
103+
* and its values are arrays indexed by partition numbers. Each array value is the sequence of
104+
* locations where that RDD partition is cached.
105+
*
106+
* All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
107+
*/
102108
private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
103109

104110
// For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
@@ -183,18 +189,17 @@ class DAGScheduler(
183189
eventProcessLoop.post(TaskSetFailed(taskSet, reason))
184190
}
185191

186-
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
187-
if (!cacheLocs.contains(rdd.id)) {
192+
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
193+
cacheLocs.getOrElseUpdate(rdd.id, {
188194
val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
189195
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
190-
cacheLocs(rdd.id) = blockIds.map { id =>
196+
blockIds.map { id =>
191197
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
192198
}
193-
}
194-
cacheLocs(rdd.id)
199+
})
195200
}
196201

197-
private def clearCacheLocs() {
202+
private def clearCacheLocs(): Unit = cacheLocs.synchronized {
198203
cacheLocs.clear()
199204
}
200205

@@ -1276,17 +1281,26 @@ class DAGScheduler(
12761281
}
12771282

12781283
/**
1279-
* Synchronized method that might be called from other threads.
1284+
* Gets the locality information associated with a partition of a particular RDD.
1285+
*
1286+
* This method is thread-safe and is called from both DAGScheduler and SparkContext.
1287+
*
12801288
* @param rdd the RDD whose partitions are to be looked at
12811289
* @param partition the partition to look up locality information for
12821290
* @return list of machines that are preferred by the partition
12831291
*/
12841292
private[spark]
1285-
def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
1293+
def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
12861294
getPreferredLocsInternal(rdd, partition, new HashSet)
12871295
}
12881296

1289-
/** Recursive implementation for getPreferredLocs. */
1297+
/**
1298+
* Recursive implementation for getPreferredLocs.
1299+
*
1300+
* This method is thread-safe because it only accesses DAGScheduler state through thread-safe
1301+
* methods (getCacheLocs()); please be careful when modifying this method, because any new
1302+
* DAGScheduler state accessed by it may require additional synchronization.
1303+
*/
12901304
private def getPreferredLocsInternal(
12911305
rdd: RDD[_],
12921306
partition: Int,

0 commit comments

Comments
 (0)