Skip to content

Commit c0d4d90

Browse files
committed
Revert "Index active task sets by stage Id rather than by task set id"
This reverts commit baf46e1.
1 parent f025154 commit c0d4d90

File tree

3 files changed

+34
-23
lines changed

3 files changed

+34
-23
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,10 @@ private[spark] class TaskSchedulerImpl(
7575

7676
// TaskSetManagers are not thread safe, so any access to one should be synchronized
7777
// on this class.
78-
val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager]
78+
val activeTaskSets = new HashMap[String, TaskSetManager]
79+
val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
7980

80-
val taskIdToStageId = new HashMap[Long, Int]
81+
val taskIdToTaskSetId = new HashMap[Long, String]
8182
val taskIdToExecutorId = new HashMap[Long, String]
8283

8384
@volatile private var hasReceivedTask = false
@@ -162,13 +163,17 @@ private[spark] class TaskSchedulerImpl(
162163
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
163164
this.synchronized {
164165
val manager = createTaskSetManager(taskSet, maxTaskFailures)
165-
stageIdToActiveTaskSet(taskSet.stageId) = manager
166-
val stageId = taskSet.stageId
167-
stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
168-
throw new IllegalStateException(
169-
s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}")
166+
activeTaskSets(taskSet.id) = manager
167+
val stage = taskSet.stageId
168+
val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
169+
stageTaskSets(taskSet.attempt) = manager
170+
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171+
ts.taskSet != taskSet && !ts.isZombie
172+
}
173+
if (conflictingTaskSet) {
174+
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
175+
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
170176
}
171-
stageIdToActiveTaskSet(stageId) = manager
172177
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
173178

174179
if (!isLocal && !hasReceivedTask) {
@@ -198,7 +203,7 @@ private[spark] class TaskSchedulerImpl(
198203

199204
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
200205
logInfo("Cancelling stage " + stageId)
201-
stageIdToActiveTaskSet.get(stageId).map {tsm =>
206+
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
202207
// There are two possible cases here:
203208
// 1. The task set manager has been created and some tasks have been scheduled.
204209
// In this case, send a kill signal to the executors to kill the task and then abort
@@ -220,7 +225,13 @@ private[spark] class TaskSchedulerImpl(
220225
* cleaned up.
221226
*/
222227
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
223-
stageIdToActiveTaskSet -= manager.stageId
228+
activeTaskSets -= manager.taskSet.id
229+
taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230+
taskSetsForStage -= manager.taskSet.attempt
231+
if (taskSetsForStage.isEmpty) {
232+
taskSetsByStage -= manager.taskSet.stageId
233+
}
234+
}
224235
manager.parent.removeSchedulable(manager)
225236
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
226237
.format(manager.taskSet.id, manager.parent.name))
@@ -241,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
241252
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
242253
tasks(i) += task
243254
val tid = task.taskId
244-
taskIdToStageId(tid) = taskSet.taskSet.stageId
255+
taskIdToTaskSetId(tid) = taskSet.taskSet.id
245256
taskIdToExecutorId(tid) = execId
246257
executorsByHost(host) += execId
247258
availableCpus(i) -= CPUS_PER_TASK
@@ -325,13 +336,13 @@ private[spark] class TaskSchedulerImpl(
325336
failedExecutor = Some(execId)
326337
}
327338
}
328-
taskIdToStageId.get(tid) match {
329-
case Some(stageId) =>
339+
taskIdToTaskSetId.get(tid) match {
340+
case Some(taskSetId) =>
330341
if (TaskState.isFinished(state)) {
331-
taskIdToStageId.remove(tid)
342+
taskIdToTaskSetId.remove(tid)
332343
taskIdToExecutorId.remove(tid)
333344
}
334-
stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
345+
activeTaskSets.get(taskSetId).foreach { taskSet =>
335346
if (state == TaskState.FINISHED) {
336347
taskSet.removeRunningTask(tid)
337348
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -369,8 +380,8 @@ private[spark] class TaskSchedulerImpl(
369380

370381
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
371382
taskMetrics.flatMap { case (id, metrics) =>
372-
taskIdToStageId.get(id)
373-
.flatMap(stageIdToActiveTaskSet.get)
383+
taskIdToTaskSetId.get(id)
384+
.flatMap(activeTaskSets.get)
374385
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
375386
}
376387
}
@@ -403,9 +414,9 @@ private[spark] class TaskSchedulerImpl(
403414

404415
def error(message: String) {
405416
synchronized {
406-
if (stageIdToActiveTaskSet.nonEmpty) {
417+
if (activeTaskSets.nonEmpty) {
407418
// Have each task set throw a SparkException with the error
408-
for ((_, manager) <- stageIdToActiveTaskSet) {
419+
for ((taskSetId, manager) <- activeTaskSets) {
409420
try {
410421
manager.abort(message)
411422
} catch {

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
191191
for (task <- tasks.flatten) {
192192
val serializedTask = ser.serialize(task)
193193
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
194-
val taskSetId = scheduler.taskIdToStageId(task.taskId)
195-
scheduler.stageIdToActiveTaskSet.get(taskSetId).foreach { taskSet =>
194+
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
195+
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
196196
try {
197197
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
198198
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
144144
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
145145

146146
// OK to submit multiple if previous attempts are all zombie
147-
taskScheduler.stageIdToActiveTaskSet(attempt1.stageId).isZombie = true
147+
taskScheduler.activeTaskSets(attempt1.id).isZombie = true
148148
taskScheduler.submitTasks(attempt2)
149149
val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null)
150150
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
151-
taskScheduler.stageIdToActiveTaskSet(attempt2.stageId).isZombie = true
151+
taskScheduler.activeTaskSets(attempt2.id).isZombie = true
152152
taskScheduler.submitTasks(attempt3)
153153
}
154154

0 commit comments

Comments
 (0)