
Commit 05b85eb

[SPARK-27474][CORE] avoid retrying a task failed with CommitDeniedException many times
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-25250 reports a bug where a task that fails with `CommitDeniedException` gets retried many times. This can happen when a stage has two task set managers, one zombie and one active. A task from the zombie TSM completes and commits to a central coordinator (assuming it's a file-writing task). The corresponding task from the active TSM then fails with `CommitDeniedException`. Since `CommitDeniedException.countTowardsTaskFailures` is false, the active TSM keeps retrying this task until the job finishes, which wastes a lot of resources.

#21131 first implemented the idea that a task successfully completed by a zombie `TaskSetManager` can mark the task for the same partition as completed in the active `TaskSetManager`. Later, #23871 improved the implementation to cover a corner case where the active `TaskSetManager` has not yet been created when the earlier task succeeds. However, #23871 had a bug and was reverted in #24359. With hindsight, #23871 is fragile because we need to sync the state between `DAGScheduler` and `TaskScheduler` about which partitions are completed.

This PR proposes a new fix:
1. When `DAGScheduler` gets a task success event from an earlier attempt, notify the `TaskSchedulerImpl` about it.
2. When `TaskSchedulerImpl` knows a partition is already completed, ask the active `TaskSetManager` to mark the corresponding task as finished, if the task is not finished yet.

This fix covers the corner case, because:
1. If `DAGScheduler` gets the task completion event from the zombie TSM before submitting the new stage attempt, then `DAGScheduler` knows that this partition is completed, and it will exclude this partition when creating the task set for the new stage attempt. See `DAGScheduler.submitMissingTasks`.
2. If `DAGScheduler` gets the task completion event from the zombie TSM after submitting the new stage attempt, then the active TSM is already created.

Compared to the previous fix, the message loop becomes longer, so the active task set manager has likely already retried the task a few times by the time the notification arrives. But this failure window won't be too big, and we want to avoid the worst case of retrying the task many times until the job finishes. So this solution is acceptable.

## How was this patch tested?

A new test case.

Closes #24375 from cloud-fan/fix2.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 20a3ef7 commit 05b85eb
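
At a high level the patch wires up a two-hop notification: the `DAGScheduler` sees a success event from an earlier (zombie) stage attempt and notifies `TaskSchedulerImpl`, which marks the partition as completed in every non-zombie `TaskSetManager` of that stage. The sketch below is a minimal, standalone Scala model of that flow; all class and method names are made up for illustration and only loosely mirror the real Spark ones shown in the diffs further down.

```scala
// Minimal, self-contained sketch of the notification flow (hypothetical names, not Spark code).
object PartitionCompletionFlowSketch {

  // Stand-in for a TaskSetManager: one per stage attempt; zombie attempts are retired ones.
  final class TaskSetManagerSketch(val stageAttemptId: Int, numPartitions: Int) {
    var isZombie: Boolean = false
    private val successful = Array.fill(numPartitions)(false)

    // Loosely mirrors markPartitionCompleted: mark the partition done if it isn't yet.
    def markPartitionCompleted(partitionId: Int): Unit =
      if (!successful(partitionId)) successful(partitionId) = true

    def completedPartitions: Seq[Int] = successful.indices.filter(i => successful(i))
  }

  // Stand-in for the scheduler side: only the active (non-zombie) task set managers of the
  // stage are asked to skip the already-completed partition.
  final class TaskSchedulerSketch {
    private var managers = Vector.empty[TaskSetManagerSketch]
    def register(tsm: TaskSetManagerSketch): Unit = managers :+= tsm

    def handlePartitionCompleted(partitionId: Int): Unit =
      managers.filterNot(_.isZombie).foreach(_.markPartitionCompleted(partitionId))
  }

  def main(args: Array[String]): Unit = {
    val scheduler = new TaskSchedulerSketch
    val zombie = new TaskSetManagerSketch(stageAttemptId = 0, numPartitions = 4)
    zombie.isZombie = true
    val active = new TaskSetManagerSketch(stageAttemptId = 1, numPartitions = 4)
    scheduler.register(zombie)
    scheduler.register(active)

    // A late task from the zombie attempt finishes partition 2. The DAGScheduler side would
    // notify the scheduler, which marks the partition completed in the active attempt.
    scheduler.handlePartitionCompleted(partitionId = 2)
    println(active.completedPartitions) // Vector(2)
  }
}
```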

File tree

9 files changed (+125, -158 lines)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 9 additions & 0 deletions
@@ -1389,6 +1389,15 @@ private[spark] class DAGScheduler(
 
     event.reason match {
       case Success =>
+        // An earlier attempt of a stage (which is zombie) may still have running tasks. If these
+        // tasks complete, they still count and we can mark the corresponding partitions as
+        // finished. Here we notify the task scheduler to skip running tasks for the same partition,
+        // to save resource.
+        if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+          taskScheduler.notifyPartitionCompletion(
+            stageId, task.partitionId, event.taskInfo.duration)
+        }
+
         task match {
           case rt: ResultTask[_, _] =>
             // Cast to ResultStage here because it's part of the ResultTask
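
The `stageAttemptId` guard above is what limits the notification to completions coming from older (zombie) attempts; completions from the latest attempt go through the normal success path. A hypothetical mini-check of that predicate, with made-up names:

```scala
// Hypothetical illustration of the guard: notify only for attempts older than the latest one.
def cameFromEarlierAttempt(taskStageAttemptId: Int, latestAttemptNumber: Int): Boolean =
  taskStageAttemptId < latestAttemptNumber

assert(cameFromEarlierAttempt(taskStageAttemptId = 0, latestAttemptNumber = 1))  // zombie attempt
assert(!cameFromEarlierAttempt(taskStageAttemptId = 1, latestAttemptNumber = 1)) // latest attempt
```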

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 10 additions & 0 deletions
@@ -155,6 +155,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
     }
   }
 
+  // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
+  // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
+  // synchronized and may hurt the throughput of the scheduler.
+  def enqueuePartitionCompletionNotification(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
+      scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration)
+    })
+  }
+
   def stop() {
     getTaskResultExecutor.shutdownNow()
   }
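
The new `enqueuePartitionCompletionNotification` hands the event to the task-result thread pool rather than having the DAGScheduler event loop call the synchronized scheduler method inline. A minimal standalone sketch of that dispatch pattern, using only the JDK executor API and hypothetical names (not the Spark classes):

```scala
import java.util.concurrent.Executors

object AsyncNotificationSketch {
  private val pool = Executors.newSingleThreadExecutor()

  // Stand-in for the synchronized scheduler method; callers could block here if contended.
  def handlePartitionCompleted(stageId: Int, partitionId: Int): Unit = synchronized {
    println(s"stage $stageId: partition $partitionId completed")
  }

  // The caller only enqueues the work and returns immediately, even if the lock is held.
  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit =
    pool.execute(() => handlePartitionCompleted(stageId, partitionId))

  def main(args: Array[String]): Unit = {
    enqueuePartitionCompletionNotification(0, 2)
    pool.shutdown()
  }
}
```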

core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ private[spark] trait TaskScheduler {
   // Throw UnsupportedOperationException if the backend doesn't support kill tasks.
   def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
 
+  // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
+  // and they can skip running tasks for it.
+  def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long)
+
   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 23 additions & 18 deletions
@@ -301,6 +301,11 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration)
+  }
+
   /**
    * Called to indicate that all task attempts (including speculated tasks) associated with the
    * given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -637,6 +642,24 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  /**
+   * Marks the task has completed in the active TaskSetManager for the given stage.
+   *
+   * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+   * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt
+   * to skip submitting and running the task for the same partition, to save resource. That also
+   * means that a task completion from an earlier zombie attempt can lead to the entire stage
+   * getting marked as successful.
+   */
+  private[scheduler] def handlePartitionCompleted(
+      stageId: Int,
+      partitionId: Int,
+      taskDuration: Long) = synchronized {
+    taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
+      tsm.markPartitionCompleted(partitionId, taskDuration)
+    })
+  }
+
   def error(message: String) {
     synchronized {
       if (taskSetsByStageIdAndAttempt.nonEmpty) {
@@ -868,24 +891,6 @@ private[spark] class TaskSchedulerImpl(
       manager
     }
   }
-
-  /**
-   * Marks the task has completed in all TaskSetManagers for the given stage.
-   *
-   * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
-   * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
-   * do not also submit those same tasks. That also means that a task completion from an earlier
-   * attempt can lead to the entire stage getting marked as successful.
-   */
-  private[scheduler] def markPartitionCompletedInAllTaskSets(
-      stageId: Int,
-      partitionId: Int,
-      taskInfo: TaskInfo) = {
-    taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId, taskInfo)
-    }
-  }
-
 }

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 2 additions & 5 deletions
@@ -806,9 +806,6 @@ private[spark] class TaskSetManager(
       logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
         " because task " + index + " has already completed successfully")
     }
-    // There may be multiple tasksets for this stage -- we let all of them know that the partition
-    // was completed. This may result in some of the tasksets getting completed.
-    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
     // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -819,11 +816,11 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }
 
-  private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
      if (!successful(index)) {
        if (speculationEnabled && !isZombie) {
-          successfulTaskDurations.insert(taskInfo.duration)
+          successfulTaskDurations.insert(taskDuration)
        }
        tasksSuccessful += 1
        successful(index) = true
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 66 additions & 0 deletions
@@ -134,6 +134,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
   val cancelledStages = new HashSet[Int]()
 
+  val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
   val taskScheduler = new TaskScheduler() {
     override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
     override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
@@ -156,6 +158,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+    override def notifyPartitionCompletion(
+      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+      taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
+        val tasks = ts.tasks.filter(_.partitionId == partitionId)
+        assert(tasks.length == 1)
+        tasksMarkedAsCompleted += tasks.head
+      }
+    }
     override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
     override def defaultParallelism() = 2
     override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -246,6 +256,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     failure = null
     sc.addSparkListener(sparkListener)
     taskSets.clear()
+    tasksMarkedAsCompleted.clear()
     cancelledStages.clear()
     cacheLocations.clear()
     results.clear()
@@ -658,6 +669,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
         stageId: Int, interruptThread: Boolean, reason: String): Unit = {
         throw new UnsupportedOperationException
       }
+      override def notifyPartitionCompletion(
+        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+        throw new UnsupportedOperationException
+      }
       override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
       override def defaultParallelism(): Int = 2
       override def executorHeartbeatReceived(
@@ -2862,6 +2877,57 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(latch.await(10, TimeUnit.SECONDS))
   }
 
+  test("Completions in zombie tasksets update status of non-zombie taskset") {
+    val parts = 4
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, (0 until parts).toArray)
+    assert(taskSets.length == 1)
+
+    // Finish the first task of the shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", 4),
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // The second task of the shuffle map stage failed with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0, 0, "ignored"),
+      null))
+
+    scheduler.resubmitFailedStages()
+    assert(taskSets.length == 2)
+    // The first partition has completed already, so the new attempt only need to run 3 tasks.
+    assert(taskSets(1).tasks.length == 3)
+
+    // Finish the first task of the second attempt of the shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0), Success, makeMapStatus("hostA", 4),
+      Seq.empty, createFakeTaskInfoWithId(0)))
+
+    // Finish the third task of the first attempt of the shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(2), Success, makeMapStatus("hostA", 4),
+      Seq.empty, createFakeTaskInfoWithId(0)))
+    assert(tasksMarkedAsCompleted.length == 1)
+    assert(tasksMarkedAsCompleted.head.partitionId == 2)
+
+    // Finish the forth task of the first attempt of the shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(3), Success, makeMapStatus("hostA", 4),
+      Seq.empty, createFakeTaskInfoWithId(0)))
+    assert(tasksMarkedAsCompleted.length == 2)
+    assert(tasksMarkedAsCompleted.last.partitionId == 3)
+
+    // Now the shuffle map stage is completed, and the next stage is submitted.
+    assert(taskSets.length == 3)
+
+    // Finish
+    complete(taskSets(2), Seq((Success, 42), (Success, 42), (Success, 42), (Success, 42)))
+    assertDataStructuresEmpty()
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.

core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ private class DummyTaskScheduler extends TaskScheduler {
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
     stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+  override def notifyPartitionCompletion(
+    stageId: Int, partitionId: Int, taskDuration: Long): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 0 additions & 104 deletions
@@ -1121,110 +1121,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     }
   }
 
-  test("Completions in zombie tasksets update status of non-zombie taskset") {
-    val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
-    val valueSer = SparkEnv.get.serializer.newInstance()
-
-    def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
-      val indexInTsm = tsm.partitionToIndex(partition)
-      val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
-      val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
-      tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
-    }
-
-    // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt,
-    // two times, so we have three active task sets for one stage. (For this to really happen,
-    // you'd need the previous stage to also get restarted, and then succeed, in between each
-    // attempt, but that happens outside what we're mocking here.)
-    val zombieAttempts = (0 until 2).map { stageAttempt =>
-      val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
-      taskScheduler.submitTasks(attempt)
-      val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
-      val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
-      taskScheduler.resourceOffers(offers)
-      assert(tsm.runningTasks === 10)
-      // fail attempt
-      tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
-        FetchFailed(null, 0, 0, 0, "fetch failed"))
-      // the attempt is a zombie, but the tasks are still running (this could be true even if
-      // we actively killed those tasks, as killing is best-effort)
-      assert(tsm.isZombie)
-      assert(tsm.runningTasks === 9)
-      tsm
-    }
-
-    // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for
-    // the stage, but this time with insufficient resources so not all tasks are active.
-
-    val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
-    taskScheduler.submitTasks(finalAttempt)
-    val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
-    val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
-    val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
-      finalAttempt.tasks(task.index).partitionId
-    }.toSet
-    assert(finalTsm.runningTasks === 5)
-    assert(!finalTsm.isZombie)
-
-    // We simulate late completions from our zombie tasksets, corresponding to all the pending
-    // partitions in our final attempt. This means we're only waiting on the tasks we've already
-    // launched.
-    val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
-    finalAttemptPendingPartitions.foreach { partition =>
-      completeTaskSuccessfully(zombieAttempts(0), partition)
-    }
-
-    // If there is another resource offer, we shouldn't run anything. Though our final attempt
-    // used to have pending tasks, now those tasks have been completed by zombie attempts. The
-    // remaining tasks to compute are already active in the non-zombie attempt.
-    assert(
-      taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)
-
-    val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted
-
-    // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
-    // marked as zombie.
-    // for each of the remaining tasks, find the tasksets with an active copy of the task, and
-    // finish the task.
-    remainingTasks.foreach { partition =>
-      val tsm = if (partition == 0) {
-        // we failed this task on both zombie attempts, this one is only present in the latest
-        // taskset
-        finalTsm
-      } else {
-        // should be active in every taskset. We choose a zombie taskset just to make sure that
-        // we transition the active taskset correctly even if the final completion comes
-        // from a zombie.
-        zombieAttempts(partition % 2)
-      }
-      completeTaskSuccessfully(tsm, partition)
-    }
-
-    assert(finalTsm.isZombie)
-
-    // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
-    verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), any())
-
-    // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything
-    // else succeeds, to make sure we get the right updates to the blacklist in all cases.
-    (zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
-      val stageAttempt = tsm.taskSet.stageAttemptId
-      tsm.runningTasksSet.foreach { index =>
-        if (stageAttempt == 1) {
-          tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
-        } else {
-          val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
-          tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
-        }
-      }
-
-      // we update the blacklist for the stage attempts with all successful tasks. Even though
-      // some tasksets had failures, we still consider them all successful from a blacklisting
-      // perspective, as the failures weren't from a problem w/ the tasks themselves.
-      verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), any())
-    }
-  }
-
   test("don't schedule for a barrier taskSet if available slots are less than pending tasks") {
     val taskCpus = 2
     val taskScheduler = setupSchedulerWithMaster(
