@@ -75,9 +75,10 @@ private[spark] class TaskSchedulerImpl(
75
75
76
76
// TaskSetManagers are not thread safe, so any access to one should be synchronized
77
77
// on this class.
78
- val stageIdToActiveTaskSet = new HashMap [Int , TaskSetManager ]
78
+ val activeTaskSets = new HashMap [String , TaskSetManager ]
79
+ val taskSetsByStage = new HashMap [Int , HashMap [Int , TaskSetManager ]]
79
80
80
- val taskIdToStageId = new HashMap [Long , Int ]
81
+ val taskIdToTaskSetId = new HashMap [Long , String ]
81
82
val taskIdToExecutorId = new HashMap [Long , String ]
82
83
83
84
@ volatile private var hasReceivedTask = false
@@ -162,13 +163,17 @@ private[spark] class TaskSchedulerImpl(
162
163
logInfo(" Adding task set " + taskSet.id + " with " + tasks.length + " tasks" )
163
164
this .synchronized {
164
165
val manager = createTaskSetManager(taskSet, maxTaskFailures)
165
- stageIdToActiveTaskSet(taskSet.stageId) = manager
166
- val stageId = taskSet.stageId
167
- stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
168
- throw new IllegalStateException (
169
- s " Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}" )
166
+ activeTaskSets(taskSet.id) = manager
167
+ val stage = taskSet.stageId
168
+ val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap [Int , TaskSetManager ])
169
+ stageTaskSets(taskSet.attempt) = manager
170
+ val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171
+ ts.taskSet != taskSet && ! ts.isZombie
172
+ }
173
+ if (conflictingTaskSet) {
174
+ throw new IllegalStateException (s " more than one active taskSet for stage $stage: " +
175
+ s " ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(" ," )}" )
170
176
}
171
- stageIdToActiveTaskSet(stageId) = manager
172
177
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
173
178
174
179
if (! isLocal && ! hasReceivedTask) {
@@ -198,7 +203,7 @@ private[spark] class TaskSchedulerImpl(
198
203
199
204
override def cancelTasks (stageId : Int , interruptThread : Boolean ): Unit = synchronized {
200
205
logInfo(" Cancelling stage " + stageId)
201
- stageIdToActiveTaskSet.get( stageId).map { tsm =>
206
+ activeTaskSets.find(_._2. stageId == stageId).foreach { case (_, tsm) =>
202
207
// There are two possible cases here:
203
208
// 1. The task set manager has been created and some tasks have been scheduled.
204
209
// In this case, send a kill signal to the executors to kill the task and then abort
@@ -220,7 +225,13 @@ private[spark] class TaskSchedulerImpl(
220
225
* cleaned up.
221
226
*/
222
227
def taskSetFinished (manager : TaskSetManager ): Unit = synchronized {
223
- stageIdToActiveTaskSet -= manager.stageId
228
+ activeTaskSets -= manager.taskSet.id
229
+ taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230
+ taskSetsForStage -= manager.taskSet.attempt
231
+ if (taskSetsForStage.isEmpty) {
232
+ taskSetsByStage -= manager.taskSet.stageId
233
+ }
234
+ }
224
235
manager.parent.removeSchedulable(manager)
225
236
logInfo(" Removed TaskSet %s, whose tasks have all completed, from pool %s"
226
237
.format(manager.taskSet.id, manager.parent.name))
@@ -241,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
241
252
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
242
253
tasks(i) += task
243
254
val tid = task.taskId
244
- taskIdToStageId (tid) = taskSet.taskSet.stageId
255
+ taskIdToTaskSetId (tid) = taskSet.taskSet.id
245
256
taskIdToExecutorId(tid) = execId
246
257
executorsByHost(host) += execId
247
258
availableCpus(i) -= CPUS_PER_TASK
@@ -325,13 +336,13 @@ private[spark] class TaskSchedulerImpl(
325
336
failedExecutor = Some (execId)
326
337
}
327
338
}
328
- taskIdToStageId .get(tid) match {
329
- case Some (stageId ) =>
339
+ taskIdToTaskSetId .get(tid) match {
340
+ case Some (taskSetId ) =>
330
341
if (TaskState .isFinished(state)) {
331
- taskIdToStageId .remove(tid)
342
+ taskIdToTaskSetId .remove(tid)
332
343
taskIdToExecutorId.remove(tid)
333
344
}
334
- stageIdToActiveTaskSet .get(stageId ).foreach { taskSet =>
345
+ activeTaskSets .get(taskSetId ).foreach { taskSet =>
335
346
if (state == TaskState .FINISHED ) {
336
347
taskSet.removeRunningTask(tid)
337
348
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -369,8 +380,8 @@ private[spark] class TaskSchedulerImpl(
369
380
370
381
val metricsWithStageIds : Array [(Long , Int , Int , TaskMetrics )] = synchronized {
371
382
taskMetrics.flatMap { case (id, metrics) =>
372
- taskIdToStageId .get(id)
373
- .flatMap(stageIdToActiveTaskSet .get)
383
+ taskIdToTaskSetId .get(id)
384
+ .flatMap(activeTaskSets .get)
374
385
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
375
386
}
376
387
}
@@ -403,9 +414,9 @@ private[spark] class TaskSchedulerImpl(
403
414
404
415
def error (message : String ) {
405
416
synchronized {
406
- if (stageIdToActiveTaskSet .nonEmpty) {
417
+ if (activeTaskSets .nonEmpty) {
407
418
// Have each task set throw a SparkException with the error
408
- for ((_ , manager) <- stageIdToActiveTaskSet ) {
419
+ for ((taskSetId , manager) <- activeTaskSets ) {
409
420
try {
410
421
manager.abort(message)
411
422
} catch {
0 commit comments