Skip to content

Commit 584acd4

Browse files
committed
simplify going from taskId to taskSetMgr
1 parent e43ac25 commit 584acd4

File tree

3 files changed

+11
-22
lines changed

3 files changed

+11
-22
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
7575

7676
// TaskSetManagers are not thread safe, so any access to one should be synchronized
7777
// on this class.
78-
val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
78+
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
7979

80-
val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
80+
private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
8181
val taskIdToExecutorId = new HashMap[Long, String]
8282

8383
@volatile private var hasReceivedTask = false
@@ -252,8 +252,7 @@ private[spark] class TaskSchedulerImpl(
252252
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253253
tasks(i) += task
254254
val tid = task.taskId
255-
taskIdToStageIdAndAttempt(tid) =
256-
(taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
255+
taskIdToTaskSetManager(tid) = taskSet
257256
taskIdToExecutorId(tid) = execId
258257
executorsByHost(host) += execId
259258
availableCpus(i) -= CPUS_PER_TASK
@@ -337,10 +336,10 @@ private[spark] class TaskSchedulerImpl(
337336
failedExecutor = Some(execId)
338337
}
339338
}
340-
taskSetManagerForTask(tid) match {
339+
taskIdToTaskSetManager.get(tid) match {
341340
case Some(taskSet) =>
342341
if (TaskState.isFinished(state)) {
343-
taskIdToStageIdAndAttempt.remove(tid)
342+
taskIdToTaskSetManager.remove(tid)
344343
taskIdToExecutorId.remove(tid)
345344
}
346345
if (state == TaskState.FINISHED) {
@@ -379,12 +378,8 @@ private[spark] class TaskSchedulerImpl(
379378

380379
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
381380
taskMetrics.flatMap { case (id, metrics) =>
382-
for {
383-
(stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
384-
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
385-
taskSetMgr <- attempts.get(stageAttemptId)
386-
} yield {
387-
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
381+
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
382+
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
388383
}
389384
}
390385
}
@@ -543,12 +538,6 @@ private[spark] class TaskSchedulerImpl(
543538

544539
override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
545540

546-
private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
547-
taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
548-
taskSetManagerForAttempt(stageId, stageAttemptId)
549-
}
550-
}
551-
552541
private[scheduler] def taskSetManagerForAttempt(
553542
stageId: Int,
554543
stageAttemptId: Int): Option[TaskSetManager] = {

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
191191
for (task <- tasks.flatten) {
192192
val serializedTask = ser.serialize(task)
193193
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
194-
scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet =>
194+
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
195195
try {
196196
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
197197
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
198198
"spark.akka.frameSize or using broadcast variables for large values."
199199
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
200200
AkkaUtils.reservedSizeBytes)
201-
taskSet.abort(msg)
201+
taskSetMgr.abort(msg)
202202
} catch {
203203
case e: Exception => logError("Exception in error callback", e)
204204
}

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
188188
taskScheduler.submitTasks(attempt2)
189189
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
190190
assert(1 === taskDescriptions3.length)
191-
val mgr = taskScheduler.taskSetManagerForTask(taskDescriptions3(0).taskId).get
191+
val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
192192
assert(mgr.taskSet.stageAttemptId === 1)
193193
}
194194

@@ -232,7 +232,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
232232
assert(10 === taskDescriptions3.length)
233233

234234
taskDescriptions3.foreach{ task =>
235-
val mgr = taskScheduler.taskSetManagerForTask(task.taskId).get
235+
val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
236236
assert(mgr.taskSet.stageAttemptId === 1)
237237
}
238238
}

0 commit comments

Comments (0)