@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(

   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
   val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]

-  val taskIdToTaskSetId = new HashMap[Long, String]
+  val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
   val taskIdToExecutorId = new HashMap[Long, String]

   @volatile private var hasReceivedTask = false
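
For orientation, the new bookkeeping replaces the single string-keyed activeTaskSets map with a two-level lookup, stage ID to stage attempt ID to TaskSetManager, plus a per-task map that records the (stageId, stageAttemptId) pair instead of a task-set ID string. The following is a minimal standalone sketch of that shape, not scheduler code: the object name, the String values standing in for TaskSetManager instances, and all IDs are invented for illustration.

    import scala.collection.mutable.HashMap

    object StageAttemptMapsSketch {
      def main(args: Array[String]): Unit = {
        // stageId -> (stageAttemptId -> manager), mirroring taskSetsByStage.
        val taskSetsByStage = new HashMap[Int, HashMap[Int, String]]
        // taskId -> (stageId, stageAttemptId), mirroring taskIdToStageIdAndAttempt.
        val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]

        // Register two attempts of stage 3 and one running task.
        // getOrElseUpdate creates the inner map on first use and reuses it afterwards.
        taskSetsByStage.getOrElseUpdate(3, new HashMap[Int, String])(0) = "manager for 3.0"
        taskSetsByStage.getOrElseUpdate(3, new HashMap[Int, String])(1) = "manager for 3.1"
        taskIdToStageIdAndAttempt(42L) = (3, 1)

        // Attempts of the same stage sit side by side under one stage key.
        println(taskSetsByStage(3).keySet.toList.sorted) // List(0, 1)
        println(taskIdToStageIdAndAttempt(42L))          // (3,1)
      }
    }

Keying by stage and attempt is what lets the later hunks (submitTasks, cancelTasks) see every attempt of a stage at once rather than a single active task set.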
@@ -163,10 +162,9 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
       val stage = taskSet.stageId
       val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
-      stageTaskSets(taskSet.attempt) = manager
+      stageTaskSets(taskSet.stageAttemptId) = manager
       val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
         ts.taskSet != taskSet && !ts.isZombie
       }
@@ -203,19 +201,21 @@ private[spark] class TaskSchedulerImpl(

   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStage.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }

@@ -225,9 +225,8 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
     taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
-      taskSetsForStage -= manager.taskSet.attempt
+      taskSetsForStage -= manager.taskSet.stageAttemptId
       if (taskSetsForStage.isEmpty) {
         taskSetsByStage -= manager.taskSet.stageId
       }
@@ -252,7 +251,7 @@ private[spark] class TaskSchedulerImpl(
           for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
             tasks(i) += task
             val tid = task.taskId
-            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskIdToStageIdAndAttempt(tid) = (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
             taskIdToExecutorId(tid) = execId
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
@@ -336,26 +335,24 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskSetManagerForTask(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToStageIdAndAttempt.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
             }
           case None =>
             logError(
               ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-              "likely the result of receiving duplicate task finished status updates)")
-              .format(state, tid))
+                "likely the result of receiving duplicate task finished status updates)")
+                .format(state, tid))
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
@@ -380,9 +377,13 @@ private[spark] class TaskSchedulerImpl(

     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        for {
+          (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
+          attempts <- taskSetsByStage.get(stageId)
+          taskSetMgr <- attempts.get(stageAttemptId)
+        } yield {
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
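
The rewritten lookup in executorHeartbeatReceived is a for comprehension over three Option-returning gets. Scala desugars such a comprehension into flatMap and map calls, so a miss at any step produces None and the surrounding taskMetrics.flatMap simply drops that task's metrics. A rough standalone equivalent, assuming plain immutable maps and made-up IDs rather than the scheduler's state:

    object OptionLookupSketch {
      def main(args: Array[String]): Unit = {
        val taskIdToStageIdAndAttempt = Map(42L -> (3, 1))
        val taskSetsByStage = Map(3 -> Map(1 -> "manager for 3.1"))

        // Approximately what the for/yield above compiles down to.
        def lookup(id: Long): Option[(Long, Int, Int, String)] =
          taskIdToStageIdAndAttempt.get(id).flatMap { case (stageId, stageAttemptId) =>
            taskSetsByStage.get(stageId).flatMap { attempts =>
              attempts.get(stageAttemptId).map(mgr => (id, stageId, stageAttemptId, mgr))
            }
          }

        println(lookup(42L)) // Some((42,3,1,manager for 3.1))
        println(lookup(99L)) // None: an unknown task ID falls out of the heartbeat quietly
      }
    }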
@@ -414,9 +415,12 @@ private[spark] class TaskSchedulerImpl(

   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStage.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStage.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -537,6 +541,21 @@ private[spark] class TaskSchedulerImpl(

   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

+  private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
+    taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
+      taskSetManagerForAttempt(stageId, stageAttemptId)
+    }
+  }
+
+  private[scheduler] def taskSetManagerForAttempt(stageId: Int, stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStage.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }

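
The two new private[scheduler] helpers keep the nested map walk in one place: statusUpdate above resolves a task ID to its TaskSetManager through an Option instead of indexing the maps directly, and taskSetManagerForAttempt factors out the stage/attempt half of that walk, so a miss surfaces as None and gets logged as a likely duplicate update. Below is a self-contained sketch of the same pattern, with a stub TaskSetManager and invented IDs, not the class above:

    import scala.collection.mutable.HashMap

    object HelperLookupSketch {
      // Stub standing in for the real TaskSetManager type, just for this sketch.
      case class TaskSetManager(stageId: Int, stageAttemptId: Int)

      val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
      val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]

      // Same shape as taskSetManagerForAttempt above: stage first, then attempt.
      def taskSetManagerForAttempt(stageId: Int, stageAttemptId: Int): Option[TaskSetManager] =
        for {
          attempts <- taskSetsByStage.get(stageId)
          manager <- attempts.get(stageAttemptId)
        } yield manager

      // Same shape as taskSetManagerForTask above: task ID -> (stage, attempt) -> manager.
      def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] =
        taskIdToStageIdAndAttempt.get(taskId).flatMap { case (stageId, stageAttemptId) =>
          taskSetManagerForAttempt(stageId, stageAttemptId)
        }

      def main(args: Array[String]): Unit = {
        taskSetsByStage.getOrElseUpdate(3, new HashMap[Int, TaskSetManager])(1) = TaskSetManager(3, 1)
        taskIdToStageIdAndAttempt(42L) = (3, 1)
        println(taskSetManagerForTask(42L)) // Some(TaskSetManager(3,1))
        println(taskSetManagerForTask(99L)) // None, e.g. a late update for a task already cleaned up
      }
    }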