
Commit 8509284

[SPARK-23433][CORE] Late zombie task completions update all tasksets
A fetch failure can lead to multiple tasksets being active for a given stage. While there is only one "active" version of the taskset, the earlier attempts can still have running tasks, and those tasks can complete successfully. A task completion therefore needs to update every taskset for the stage, so that the final active taskset does not submit another task for the same partition, and so that each taskset knows when it is complete and should be marked as a "zombie".

Added a regression test.

Author: Imran Rashid <irashid@cloudera.com>

Closes #21131 from squito/SPARK-23433.

(cherry picked from commit 94641fe)

Signed-off-by: Imran Rashid <irashid@cloudera.com>
1 parent 61e7bc0 commit 8509284
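A minimal, self-contained Scala sketch of the behaviour described in the commit message, using a toy scheduler model rather than the real Spark classes (ToyTaskSet, ToyScheduler, and Demo are illustrative names, and the real scheduler tracks far more state): a completion reported by any attempt marks the partition complete in every attempt, so the active attempt neither re-submits that partition nor keeps waiting once all partitions are accounted for.

import scala.collection.mutable

// Toy model of one stage with several taskset attempts (illustrative, not Spark code).
class ToyTaskSet(val stageAttemptId: Int, numPartitions: Int) {
  private val done = mutable.Set.empty[Int]
  var isZombie = false

  def markPartitionCompleted(partition: Int): Unit = {
    // Once every partition is done, this attempt has nothing left to do.
    if (done.add(partition) && done.size == numPartitions) {
      isZombie = true
    }
  }

  def pendingPartitions: Set[Int] = (0 until numPartitions).toSet.diff(done.toSet)
}

class ToyScheduler(numPartitions: Int) {
  private val attempts = mutable.ArrayBuffer.empty[ToyTaskSet]

  def newAttempt(): ToyTaskSet = {
    val ts = new ToyTaskSet(attempts.size, numPartitions)
    attempts += ts
    ts
  }

  // The core of the fix: a finished task updates *every* attempt for the stage,
  // not just the attempt that launched it.
  def taskFinished(partition: Int): Unit = attempts.foreach(_.markPartitionCompleted(partition))
}

object Demo extends App {
  val sched = new ToyScheduler(numPartitions = 3)
  val zombie = sched.newAttempt()   // attempt 0, superseded after a fetch failure
  val active = sched.newAttempt()   // attempt 1, the current retry

  sched.taskFinished(0)             // late completion from a task launched by attempt 0
  println(active.pendingPartitions) // partition 0 is never re-submitted by the retry

  sched.taskFinished(1)
  sched.taskFinished(2)
  println(active.isZombie)          // true: the stage is fully complete
}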

3 files changed: 137 additions & 1 deletion

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 14 additions & 0 deletions
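The hunk below adds markPartitionCompletedInAllTaskSets. The bookkeeping it walks is, roughly, a map from stage id to a map from stage-attempt id to TaskSetManager. The sketch here only shows that lookup shape, including the getOrElse no-op when the stage is no longer registered; Manager and AllTaskSetsSketch are made-up stand-ins, not Spark classes.

import scala.collection.mutable.HashMap

object AllTaskSetsSketch extends App {
  // Simplified stand-in for a TaskSetManager -- only the method this commit calls.
  class Manager(val stageAttemptId: Int) {
    def markPartitionCompleted(partitionId: Int): Unit =
      println(s"attempt $stageAttemptId told that partition $partitionId is done")
  }

  // stageId -> (stageAttemptId -> manager), mirroring taskSetsByStageIdAndAttempt.
  val byStageAndAttempt = new HashMap[Int, HashMap[Int, Manager]]
  byStageAndAttempt(0) = HashMap(0 -> new Manager(0), 1 -> new Manager(1))

  def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int): Unit = {
    // An unregistered stage is a no-op; otherwise every attempt, zombie or active, is notified.
    byStageAndAttempt.getOrElse(stageId, HashMap.empty[Int, Manager]).values.foreach { m =>
      m.markPartitionCompleted(partitionId)
    }
  }

  markPartitionCompletedInAllTaskSets(0, 7)   // both attempts of stage 0 are notified
  markPartitionCompletedInAllTaskSets(9, 7)   // stage 9 is unknown: nothing happens
}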
@@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  /**
+   * Marks the task has completed in all TaskSetManagers for the given stage.
+   *
+   * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+   * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
+   * do not also submit those same tasks. That also means that a task completion from an earlier
+   * attempt can lead to the entire stage getting marked as successful.
+   */
+  private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
+    taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
+      tsm.markPartitionCompleted(partitionId)
+    }
+  }
+
 }
 
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 19 additions & 1 deletion
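Two things are worth noting about the hunks below. First, a task's index is its position within one particular taskset, and a retried taskset normally contains only the partitions that are still missing, so the same partition can sit at a different index in each attempt; the partition id is the only stable key across attempts, which is what the new partitionToIndex map provides. Second, taskInfos (and the new partitionToIndex) are exposed as private[scheduler] so that the regression test can reach them. A tiny illustrative sketch of the index-versus-partition point, with made-up data rather than Spark code:

object PartitionIndexSketch extends App {
  // Each attempt holds some subset of the stage's partitions; a task's index is its
  // position within that attempt, so it is not comparable across attempts.
  case class ToyTask(partitionId: Int)

  val attempt0 = Vector(ToyTask(0), ToyTask(1), ToyTask(2), ToyTask(3))
  val attempt1 = Vector(ToyTask(2), ToyTask(3))   // the retry only covers the missing partitions

  def partitionToIndex(tasks: Vector[ToyTask]): Map[Int, Int] =
    tasks.zipWithIndex.map { case (t, idx) => t.partitionId -> idx }.toMap

  // Partition 3 is index 3 in attempt 0 but index 1 in attempt 1, so a completion reported
  // by partition id has to be translated into a per-attempt index before updating state.
  assert(partitionToIndex(attempt0)(3) == 3)
  assert(partitionToIndex(attempt1)(3) == 1)
  println("partition 3 -> index " + partitionToIndex(attempt1)(3) + " in attempt 1")
}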
@@ -74,6 +74,8 @@ private[spark] class TaskSetManager(
   val ser = env.closureSerializer.newInstance()
 
   val tasks = taskSet.tasks
+  private[scheduler] val partitionToIndex = tasks.zipWithIndex
+    .map { case (t, idx) => t.partitionId -> idx }.toMap
   val numTasks = tasks.length
   val copiesRunning = new Array[Int](numTasks)
 
@@ -154,7 +156,7 @@
   private[scheduler] val speculatableTasks = new HashSet[Int]
 
   // Task index, start and finish time for each task attempt (indexed by task ID)
-  private val taskInfos = new HashMap[Long, TaskInfo]
+  private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
 
   // Use a MedianHeap to record durations of successful tasks so we know when to launch
   // speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -755,6 +757,9 @@
       logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
         " because task " + index + " has already completed successfully")
     }
+    // There may be multiple tasksets for this stage -- we let all of them know that the partition
+    // was completed. This may result in some of the tasksets getting completed.
+    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
     // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -765,6 +770,19 @@
     maybeFinishTaskSet()
   }
 
+  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+    partitionToIndex.get(partitionId).foreach { index =>
+      if (!successful(index)) {
+        tasksSuccessful += 1
+        successful(index) = true
+        if (tasksSuccessful == numTasks) {
+          isZombie = true
+        }
+        maybeFinishTaskSet()
+      }
+    }
+  }
+
   /**
    * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
    * DAG Scheduler.

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 104 additions & 0 deletions
@@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       taskScheduler.initialize(new FakeSchedulerBackend)
     }
   }
+
+  test("Completions in zombie tasksets update status of non-zombie taskset") {
+    val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+    val valueSer = SparkEnv.get.serializer.newInstance()
+
+    def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
+      val indexInTsm = tsm.partitionToIndex(partition)
+      val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
+      val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+      tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
+    }
+
+    // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt,
+    // two times, so we have three active task sets for one stage. (For this to really happen,
+    // you'd need the previous stage to also get restarted, and then succeed, in between each
+    // attempt, but that happens outside what we're mocking here.)
+    val zombieAttempts = (0 until 2).map { stageAttempt =>
+      val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
+      taskScheduler.submitTasks(attempt)
+      val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
+      val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+      taskScheduler.resourceOffers(offers)
+      assert(tsm.runningTasks === 10)
+      // fail attempt
+      tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
+        FetchFailed(null, 0, 0, 0, "fetch failed"))
+      // the attempt is a zombie, but the tasks are still running (this could be true even if
+      // we actively killed those tasks, as killing is best-effort)
+      assert(tsm.isZombie)
+      assert(tsm.runningTasks === 9)
+      tsm
+    }
+
+    // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for
+    // the stage, but this time with insufficient resources so not all tasks are active.
+
+    val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
+    taskScheduler.submitTasks(finalAttempt)
+    val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
+    val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+    val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
+      finalAttempt.tasks(task.index).partitionId
+    }.toSet
+    assert(finalTsm.runningTasks === 5)
+    assert(!finalTsm.isZombie)
+
+    // We simulate late completions from our zombie tasksets, corresponding to all the pending
+    // partitions in our final attempt. This means we're only waiting on the tasks we've already
+    // launched.
+    val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
+    finalAttemptPendingPartitions.foreach { partition =>
+      completeTaskSuccessfully(zombieAttempts(0), partition)
+    }
+
+    // If there is another resource offer, we shouldn't run anything. Though our final attempt
+    // used to have pending tasks, now those tasks have been completed by zombie attempts. The
+    // remaining tasks to compute are already active in the non-zombie attempt.
+    assert(
+      taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)
+
+    val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted
+
+    // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
+    // marked as zombie.
+    // for each of the remaining tasks, find the tasksets with an active copy of the task, and
+    // finish the task.
+    remainingTasks.foreach { partition =>
+      val tsm = if (partition == 0) {
+        // we failed this task on both zombie attempts, this one is only present in the latest
+        // taskset
+        finalTsm
+      } else {
+        // should be active in every taskset. We choose a zombie taskset just to make sure that
+        // we transition the active taskset correctly even if the final completion comes
+        // from a zombie.
+        zombieAttempts(partition % 2)
+      }
+      completeTaskSuccessfully(tsm, partition)
+    }
+
+    assert(finalTsm.isZombie)
+
+    // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
+    verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject())
+
+    // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything
+    // else succeeds, to make sure we get the right updates to the blacklist in all cases.
+    (zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
+      val stageAttempt = tsm.taskSet.stageAttemptId
+      tsm.runningTasksSet.foreach { index =>
+        if (stageAttempt == 1) {
+          tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
+        } else {
+          val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+          tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
+        }
+      }
+
+      // we update the blacklist for the stage attempts with all successful tasks. Even though
+      // some tasksets had failures, we still consider them all successful from a blacklisting
+      // perspective, as the failures weren't from a problem w/ the tasks themselves.
+      verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject())
+    }
+  }
 }
