@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
75
75
76
76
// TaskSetManagers are not thread safe, so any access to one should be synchronized
77
77
// on this class.
78
- val taskSetsByStageIdAndAttempt = new HashMap [Int , HashMap [Int , TaskSetManager ]]
78
+ private val taskSetsByStageIdAndAttempt = new HashMap [Int , HashMap [Int , TaskSetManager ]]
79
79
80
- val taskIdToStageIdAndAttempt = new HashMap [Long , ( Int , Int ) ]
80
+ private [scheduler] val taskIdToTaskSetManager = new HashMap [Long , TaskSetManager ]
81
81
val taskIdToExecutorId = new HashMap [Long , String ]
82
82
83
83
@ volatile private var hasReceivedTask = false
@@ -252,8 +252,7 @@ private[spark] class TaskSchedulerImpl(
252
252
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253
253
tasks(i) += task
254
254
val tid = task.taskId
255
- taskIdToStageIdAndAttempt(tid) =
256
- (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
255
+ taskIdToTaskSetManager(tid) = taskSet
257
256
taskIdToExecutorId(tid) = execId
258
257
executorsByHost(host) += execId
259
258
availableCpus(i) -= CPUS_PER_TASK
@@ -337,10 +336,10 @@ private[spark] class TaskSchedulerImpl(
337
336
failedExecutor = Some (execId)
338
337
}
339
338
}
340
- taskSetManagerForTask (tid) match {
339
+ taskIdToTaskSetManager.get (tid) match {
341
340
case Some (taskSet) =>
342
341
if (TaskState .isFinished(state)) {
343
- taskIdToStageIdAndAttempt .remove(tid)
342
+ taskIdToTaskSetManager .remove(tid)
344
343
taskIdToExecutorId.remove(tid)
345
344
}
346
345
if (state == TaskState .FINISHED ) {
@@ -379,12 +378,8 @@ private[spark] class TaskSchedulerImpl(
379
378
380
379
val metricsWithStageIds : Array [(Long , Int , Int , TaskMetrics )] = synchronized {
381
380
taskMetrics.flatMap { case (id, metrics) =>
382
- for {
383
- (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
384
- attempts <- taskSetsByStageIdAndAttempt.get(stageId)
385
- taskSetMgr <- attempts.get(stageAttemptId)
386
- } yield {
387
- (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
381
+ taskIdToTaskSetManager.get(id).map { taskSetMgr =>
382
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
388
383
}
389
384
}
390
385
}
@@ -543,12 +538,6 @@ private[spark] class TaskSchedulerImpl(
543
538
544
539
override def applicationAttemptId (): Option [String ] = backend.applicationAttemptId()
545
540
546
- private [scheduler] def taskSetManagerForTask (taskId : Long ): Option [TaskSetManager ] = {
547
- taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
548
- taskSetManagerForAttempt(stageId, stageAttemptId)
549
- }
550
- }
551
-
552
541
private [scheduler] def taskSetManagerForAttempt (
553
542
stageId : Int ,
554
543
stageAttemptId : Int ): Option [TaskSetManager ] = {
0 commit comments