Skip to content

Commit d7f1ef2

Browse files
committed
get rid of activeTaskSets
1 parent a21c8b5 commit d7f1ef2

File tree

3 files changed

+59
-41
lines changed

3 files changed

+59
-41
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(
7575

7676
// TaskSetManagers are not thread safe, so any access to one should be synchronized
7777
// on this class.
78-
val activeTaskSets = new HashMap[String, TaskSetManager]
7978
val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
8079

81-
val taskIdToTaskSetId = new HashMap[Long, String]
80+
val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
8281
val taskIdToExecutorId = new HashMap[Long, String]
8382

8483
@volatile private var hasReceivedTask = false
@@ -163,10 +162,9 @@ private[spark] class TaskSchedulerImpl(
163162
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
164163
this.synchronized {
165164
val manager = createTaskSetManager(taskSet, maxTaskFailures)
166-
activeTaskSets(taskSet.id) = manager
167165
val stage = taskSet.stageId
168166
val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
169-
stageTaskSets(taskSet.attempt) = manager
167+
stageTaskSets(taskSet.stageAttemptId) = manager
170168
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171169
ts.taskSet != taskSet && !ts.isZombie
172170
}
@@ -203,19 +201,21 @@ private[spark] class TaskSchedulerImpl(
203201

204202
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
205203
logInfo("Cancelling stage " + stageId)
206-
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
207-
// There are two possible cases here:
208-
// 1. The task set manager has been created and some tasks have been scheduled.
209-
// In this case, send a kill signal to the executors to kill the task and then abort
210-
// the stage.
211-
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
212-
// simply abort the stage.
213-
tsm.runningTasksSet.foreach { tid =>
214-
val execId = taskIdToExecutorId(tid)
215-
backend.killTask(tid, execId, interruptThread)
204+
taskSetsByStage.get(stageId).foreach { attempts =>
205+
attempts.foreach { case (_, tsm) =>
206+
// There are two possible cases here:
207+
// 1. The task set manager has been created and some tasks have been scheduled.
208+
// In this case, send a kill signal to the executors to kill the task and then abort
209+
// the stage.
210+
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
211+
// simply abort the stage.
212+
tsm.runningTasksSet.foreach { tid =>
213+
val execId = taskIdToExecutorId(tid)
214+
backend.killTask(tid, execId, interruptThread)
215+
}
216+
tsm.abort("Stage %s cancelled".format(stageId))
217+
logInfo("Stage %d was cancelled".format(stageId))
216218
}
217-
tsm.abort("Stage %s cancelled".format(stageId))
218-
logInfo("Stage %d was cancelled".format(stageId))
219219
}
220220
}
221221

@@ -225,9 +225,8 @@ private[spark] class TaskSchedulerImpl(
225225
* cleaned up.
226226
*/
227227
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
228-
activeTaskSets -= manager.taskSet.id
229228
taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230-
taskSetsForStage -= manager.taskSet.attempt
229+
taskSetsForStage -= manager.taskSet.stageAttemptId
231230
if (taskSetsForStage.isEmpty) {
232231
taskSetsByStage -= manager.taskSet.stageId
233232
}
@@ -252,7 +251,7 @@ private[spark] class TaskSchedulerImpl(
252251
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253252
tasks(i) += task
254253
val tid = task.taskId
255-
taskIdToTaskSetId(tid) = taskSet.taskSet.id
254+
taskIdToStageIdAndAttempt(tid) = (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
256255
taskIdToExecutorId(tid) = execId
257256
executorsByHost(host) += execId
258257
availableCpus(i) -= CPUS_PER_TASK
@@ -336,26 +335,24 @@ private[spark] class TaskSchedulerImpl(
336335
failedExecutor = Some(execId)
337336
}
338337
}
339-
taskIdToTaskSetId.get(tid) match {
340-
case Some(taskSetId) =>
338+
taskSetManagerForTask(tid) match {
339+
case Some(taskSet) =>
341340
if (TaskState.isFinished(state)) {
342-
taskIdToTaskSetId.remove(tid)
341+
taskIdToStageIdAndAttempt.remove(tid)
343342
taskIdToExecutorId.remove(tid)
344343
}
345-
activeTaskSets.get(taskSetId).foreach { taskSet =>
346-
if (state == TaskState.FINISHED) {
347-
taskSet.removeRunningTask(tid)
348-
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
349-
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
350-
taskSet.removeRunningTask(tid)
351-
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
352-
}
344+
if (state == TaskState.FINISHED) {
345+
taskSet.removeRunningTask(tid)
346+
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
347+
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
348+
taskSet.removeRunningTask(tid)
349+
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
353350
}
354351
case None =>
355352
logError(
356353
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
357-
"likely the result of receiving duplicate task finished status updates)")
358-
.format(state, tid))
354+
"likely the result of receiving duplicate task finished status updates)")
355+
.format(state, tid))
359356
}
360357
} catch {
361358
case e: Exception => logError("Exception in statusUpdate", e)
@@ -380,9 +377,13 @@ private[spark] class TaskSchedulerImpl(
380377

381378
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
382379
taskMetrics.flatMap { case (id, metrics) =>
383-
taskIdToTaskSetId.get(id)
384-
.flatMap(activeTaskSets.get)
385-
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
380+
for {
381+
(stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
382+
attempts <- taskSetsByStage.get(stageId)
383+
taskSetMgr <- attempts.get(stageAttemptId)
384+
} yield {
385+
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
386+
}
386387
}
387388
}
388389
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -414,9 +415,12 @@ private[spark] class TaskSchedulerImpl(
414415

415416
def error(message: String) {
416417
synchronized {
417-
if (activeTaskSets.nonEmpty) {
418+
if (taskSetsByStage.nonEmpty) {
418419
// Have each task set throw a SparkException with the error
419-
for ((taskSetId, manager) <- activeTaskSets) {
420+
for {
421+
attempts <- taskSetsByStage.values
422+
manager <- attempts.values
423+
} {
420424
try {
421425
manager.abort(message)
422426
} catch {
@@ -537,6 +541,21 @@ private[spark] class TaskSchedulerImpl(
537541

538542
override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
539543

544+
private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
545+
taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
546+
taskSetManagerForAttempt(stageId, stageAttemptId)
547+
}
548+
}
549+
550+
private[scheduler] def taskSetManagerForAttempt(stageId: Int, stageAttemptId: Int): Option[TaskSetManager] = {
551+
for {
552+
attempts <- taskSetsByStage.get(stageId)
553+
manager <- attempts.get(stageAttemptId)
554+
} yield {
555+
manager
556+
}
557+
}
558+
540559
}
541560

542561

core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ import java.util.Properties
2626
private[spark] class TaskSet(
2727
val tasks: Array[Task[_]],
2828
val stageId: Int,
29-
val attempt: Int,
29+
val stageAttemptId: Int,
3030
val priority: Int,
3131
val properties: Properties) {
32-
val id: String = stageId + "." + attempt
32+
val id: String = stageId + "." + stageAttemptId
3333

3434
override def toString: String = "TaskSet " + id
3535
}

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
191191
for (task <- tasks.flatten) {
192192
val serializedTask = ser.serialize(task)
193193
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
194-
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
195-
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
194+
scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet =>
196195
try {
197196
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
198197
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +

0 commit comments

Comments
 (0)