@@ -75,7 +75,7 @@ private[spark] class TaskSchedulerImpl(
75
75
76
76
// TaskSetManagers are not thread safe, so any access to one should be synchronized
77
77
// on this class.
78
- val taskSetsByStage = new HashMap [Int , HashMap [Int , TaskSetManager ]]
78
+ val taskSetsByStageIdAndAttempt = new HashMap [Int , HashMap [Int , TaskSetManager ]]
79
79
80
80
val taskIdToStageIdAndAttempt = new HashMap [Long , (Int , Int )]
81
81
val taskIdToExecutorId = new HashMap [Long , String ]
@@ -163,7 +163,8 @@ private[spark] class TaskSchedulerImpl(
163
163
this .synchronized {
164
164
val manager = createTaskSetManager(taskSet, maxTaskFailures)
165
165
val stage = taskSet.stageId
166
- val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap [Int , TaskSetManager ])
166
+ val stageTaskSets =
167
+ taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap [Int , TaskSetManager ])
167
168
stageTaskSets(taskSet.stageAttemptId) = manager
168
169
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
169
170
ts.taskSet != taskSet && ! ts.isZombie
@@ -201,7 +202,7 @@ private[spark] class TaskSchedulerImpl(
201
202
202
203
override def cancelTasks (stageId : Int , interruptThread : Boolean ): Unit = synchronized {
203
204
logInfo(" Cancelling stage " + stageId)
204
- taskSetsByStage .get(stageId).foreach { attempts =>
205
+ taskSetsByStageIdAndAttempt .get(stageId).foreach { attempts =>
205
206
attempts.foreach { case (_, tsm) =>
206
207
// There are two possible cases here:
207
208
// 1. The task set manager has been created and some tasks have been scheduled.
@@ -225,10 +226,10 @@ private[spark] class TaskSchedulerImpl(
225
226
* cleaned up.
226
227
*/
227
228
def taskSetFinished (manager : TaskSetManager ): Unit = synchronized {
228
- taskSetsByStage .get(manager.taskSet.stageId).foreach { taskSetsForStage =>
229
+ taskSetsByStageIdAndAttempt .get(manager.taskSet.stageId).foreach { taskSetsForStage =>
229
230
taskSetsForStage -= manager.taskSet.stageAttemptId
230
231
if (taskSetsForStage.isEmpty) {
231
- taskSetsByStage -= manager.taskSet.stageId
232
+ taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
232
233
}
233
234
}
234
235
manager.parent.removeSchedulable(manager)
@@ -380,7 +381,7 @@ private[spark] class TaskSchedulerImpl(
380
381
taskMetrics.flatMap { case (id, metrics) =>
381
382
for {
382
383
(stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
383
- attempts <- taskSetsByStage .get(stageId)
384
+ attempts <- taskSetsByStageIdAndAttempt .get(stageId)
384
385
taskSetMgr <- attempts.get(stageAttemptId)
385
386
} yield {
386
387
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
@@ -416,10 +417,10 @@ private[spark] class TaskSchedulerImpl(
416
417
417
418
def error (message : String ) {
418
419
synchronized {
419
- if (taskSetsByStage .nonEmpty) {
420
+ if (taskSetsByStageIdAndAttempt .nonEmpty) {
420
421
// Have each task set throw a SparkException with the error
421
422
for {
422
- attempts <- taskSetsByStage .values
423
+ attempts <- taskSetsByStageIdAndAttempt .values
423
424
manager <- attempts.values
424
425
} {
425
426
try {
@@ -552,7 +553,7 @@ private[spark] class TaskSchedulerImpl(
552
553
stageId : Int ,
553
554
stageAttemptId : Int ): Option [TaskSetManager ] = {
554
555
for {
555
- attempts <- taskSetsByStage .get(stageId)
556
+ attempts <- taskSetsByStageIdAndAttempt .get(stageId)
556
557
manager <- attempts.get(stageAttemptId)
557
558
} yield {
558
559
manager
0 commit comments