@@ -76,6 +76,7 @@ private[spark] class TaskSchedulerImpl(
76
76
// TaskSetManagers are not thread safe, so any access to one should be synchronized
77
77
// on this class.
78
78
val activeTaskSets = new HashMap [String , TaskSetManager ]
79
+ val taskSetsByStage = new HashMap [Int , HashMap [Int , TaskSetManager ]]
79
80
80
81
val taskIdToTaskSetId = new HashMap [Long , String ]
81
82
val taskIdToExecutorId = new HashMap [Long , String ]
@@ -164,13 +165,14 @@ private[spark] class TaskSchedulerImpl(
164
165
val manager = createTaskSetManager(taskSet, maxTaskFailures)
165
166
activeTaskSets(taskSet.id) = manager
166
167
val stage = taskSet.stageId
167
- val conflictingTaskSet = activeTaskSets.exists { case (id, ts) =>
168
- // if the id matches, it really should be the same taskSet, but in some unit tests
169
- // we add new taskSets with the same id
170
- id != taskSet.id && ! ts.isZombie && ts.stageId == stage
168
+ val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap [ Int , TaskSetManager ])
169
+ stageTaskSets( taskSet.attempt) = manager
170
+ val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171
+ ts.taskSet != taskSet && ! ts.isZombie
171
172
}
172
173
if (conflictingTaskSet) {
173
- throw new SparkIllegalStateException (s " more than one active taskSet for stage $stage" )
174
+ throw new SparkIllegalStateException (s " more than one active taskSet for stage $stage: " +
175
+ s " ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(" ," )}" )
174
176
}
175
177
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
176
178
@@ -224,6 +226,12 @@ private[spark] class TaskSchedulerImpl(
224
226
*/
225
227
def taskSetFinished (manager : TaskSetManager ): Unit = synchronized {
226
228
activeTaskSets -= manager.taskSet.id
229
+ taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230
+ taskSetsForStage -= manager.taskSet.attempt
231
+ if (taskSetsForStage.isEmpty) {
232
+ taskSetsByStage -= manager.taskSet.stageId
233
+ }
234
+ }
227
235
manager.parent.removeSchedulable(manager)
228
236
logInfo(" Removed TaskSet %s, whose tasks have all completed, from pool %s"
229
237
.format(manager.taskSet.id, manager.parent.name))
0 commit comments