
Commit c443def

better fix and simpler test case
1 parent 28d70aa commit c443def

2 files changed: +60 -56 lines


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 34 additions & 31 deletions
@@ -1102,44 +1102,47 @@ class DAGScheduler(
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)
+        if (failedStage.attemptId - 1 > task.stageAttemptId) {
+          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+            s" ${task.stageAttemptId}, which has already failed")
+        } else {

-        // It is likely that we receive multiple FetchFailed for a single stage (because we have
-        // multiple tasks running concurrently on different executors). In that case, it is possible
-        // the fetch failure has already been handled by the scheduler.
-        if (runningStages.contains(failedStage)) {
-          if (failedStage.attemptId - 1 > task.stageAttemptId) {
-            logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
-              s" ${task.stageAttemptId}, which has already failed")
-          } else {
+          // It is likely that we receive multiple FetchFailed for a single stage (because we have
+          // multiple tasks running concurrently on different executors). In that case, it is possible
+          // the fetch failure has already been handled by the scheduler.
+          if (runningStages.contains(failedStage)) {
             logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
               s"due to a fetch failure from $mapStage (${mapStage.name})")
             markStageAsFinished(failedStage, Some(failureMessage))
+          } else {
+            logInfo(s"Ignoring fetch failure from $task as it's from $failedStage, " +
+              s"which is no longer running")
           }
-        }

-        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty) {
-          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled.
-          // TODO: Cancel running tasks in the stage
-          logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-            s"$failedStage (${failedStage.name}) due to fetch failure")
-          messageScheduler.schedule(new Runnable {
-            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-        }
-        failedStages += failedStage
-        failedStages += mapStage
-        // Mark the map whose fetch failed as broken in the map stage
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
+          if (disallowStageRetryForTest) {
+            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (failedStages.isEmpty) {
+            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+            // in that case the event will already have been scheduled.
+            // TODO: Cancel running tasks in the stage
+            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+              s"$failedStage (${failedStage.name}) due to fetch failure")
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+          }
+          failedStages += failedStage
+          failedStages += mapStage
+          // Mark the map whose fetch failed as broken in the map stage
+          if (mapId != -1) {
+            mapStage.removeOutputLoc(mapId, bmAddress)
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }

-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          // TODO: mark the executor as failed only if there were lots of fetch failures on it
+          if (bmAddress != null) {
+            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          }
         }

       case commitDenied: TaskCommitDenied =>
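
The heart of the fix is the new guard at the very top of the FetchFailed handler: a fetch failure reported by a task from a stage attempt that has already failed is logged and dropped before it can do anything else. Before this change, the attempt check sat inside the runningStages branch and only guarded the log message and the markStageAsFinished call; the resubmission scheduling and map-output unregistration below it still ran for stale failures. Moving the check outermost puts all of that behind the guard. A minimal, self-contained sketch of the idea (Stage, FetchFailure and handleFetchFailure below are made-up stand-ins, not Spark's real types; the only assumption carried over from the diff is that attemptId - 1 identifies the stage's most recent attempt):

// Minimal sketch (not Spark code) of the staleness guard introduced above.
object StaleFetchFailureSketch {

  case class Stage(id: Int, name: String, attemptId: Int)
  case class FetchFailure(stageId: Int, stageAttemptId: Int)

  private val runningStages = scala.collection.mutable.Set[Int]()

  def handleFetchFailure(failedStage: Stage, failure: FetchFailure): Unit = {
    if (failedStage.attemptId - 1 > failure.stageAttemptId) {
      // Failure comes from an attempt that has already failed and been superseded:
      // drop it so it cannot trigger a second, concurrent resubmission.
      println(s"Ignoring fetch failure from attempt ${failure.stageAttemptId} of " +
        s"${failedStage.name}, which has already failed")
    } else {
      if (runningStages.contains(failedStage.id)) {
        println(s"Marking ${failedStage.name} as failed")
        runningStages -= failedStage.id
      } else {
        // Several tasks of the same attempt may report the same fetch failure; after
        // the first one the stage is no longer running, so the rest are just logged.
        println(s"Ignoring fetch failure from ${failedStage.name}, which is no longer running")
      }
      // ... the real handler would also queue the stage for resubmission and
      // unregister the lost map output here.
    }
  }

  def main(args: Array[String]): Unit = {
    val stage = Stage(id = 2, name = "stage 2", attemptId = 2) // attempt 1 is the latest
    runningStages += stage.id
    handleFetchFailure(stage, FetchFailure(2, stageAttemptId = 0)) // stale -> ignored
    handleFetchFailure(stage, FetchFailure(2, stageAttemptId = 1)) // current -> stage marked failed
    handleFetchFailure(stage, FetchFailure(2, stageAttemptId = 1)) // duplicate -> no longer running
  }
}
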

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFailureRecoverySuite.scala

Lines changed: 26 additions & 25 deletions
@@ -26,35 +26,33 @@ import org.apache.spark._
 
 class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
 
-  // TODO we should run this with a matrix of configurations: different shufflers,
-  // external shuffle service, etc. But that is really pushing the question of how to run
-  // such a long test ...
-
-  ignore("no concurrent retries for stage attempts (SPARK-7308)") {
-    // see SPARK-7308 for a detailed description of the conditions this is trying to recreate.
-    // note that this is somewhat convoluted for a test case, but isn't actually very unusual
-    // under a real workload. We only fail the first attempt of stage 2, but that
-    // could be enough to cause havoc.
-
-    (0 until 100).foreach { idx =>
-      println(new Date() + "\ttrial " + idx)
+  test("no concurrent retries for stage attempts (SPARK-8103)") {
+    // make sure that if we get fetch failures after the retry has started, we ignore them,
+    // and so don't end up submitting multiple concurrent attempts for the same stage
+
+    (0 until 20).foreach { idx =>
       logInfo(new Date() + "\ttrial " + idx)
 
       val conf = new SparkConf().set("spark.executor.memory", "100m")
-      val clusterSc = new SparkContext("local-cluster[5,4,100]", "test-cluster", conf)
+      val clusterSc = new SparkContext("local-cluster[2,2,100]", "test-cluster", conf)
       val bms = ArrayBuffer[BlockManagerId]()
       val stageFailureCount = HashMap[Int, Int]()
+      val stageSubmissionCount = HashMap[Int, Int]()
       clusterSc.addSparkListener(new SparkListener {
         override def onBlockManagerAdded(bmAdded: SparkListenerBlockManagerAdded): Unit = {
           bms += bmAdded.blockManagerId
         }
 
+        override def onStageSubmitted(stageSubmited: SparkListenerStageSubmitted): Unit = {
+          val stage = stageSubmited.stageInfo.stageId
+          stageSubmissionCount(stage) = stageSubmissionCount.getOrElse(stage, 0) + 1
+        }
+
+
         override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
           if (stageCompleted.stageInfo.failureReason.isDefined) {
             val stage = stageCompleted.stageInfo.stageId
             stageFailureCount(stage) = stageFailureCount.getOrElse(stage, 0) + 1
-            val reason = stageCompleted.stageInfo.failureReason.get
-            println("stage " + stage + " failed: " + stageFailureCount(stage))
           }
         }
       })
@@ -66,34 +64,37 @@ class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
         // to avoid broadcast failures
         val someBlockManager = bms.filter{!_.isDriver}(0)
 
-        val shuffled = rawData.groupByKey(100).mapPartitionsWithIndex { case (idx, itr) =>
+        val shuffled = rawData.groupByKey(20).mapPartitionsWithIndex { case (idx, itr) =>
           // we want one failure quickly, and more failures after stage 0 has finished its
           // second attempt
           val stageAttemptId = TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId
           if (stageAttemptId == 0) {
             if (idx == 0) {
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else if (idx > 0 && math.random < 0.2) {
-              Thread.sleep(5000)
+            } else if (idx == 1) {
+              Thread.sleep(2000)
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else {
-              // want to make sure plenty of these finish after task 0 fails, and some even finish
-              // after the previous stage is retried and this stage retry is started
-              Thread.sleep((500 + math.random * 5000).toLong)
             }
+          } else {
+            // just to make sure the second attempt doesn't finish before we trigger more failures
+            // from the first attempt
+            Thread.sleep(2000)
           }
           itr.map { x => ((x._1 + 5) % 100) -> x._2 }
         }
-        val data = shuffled.mapPartitions { itr => itr.flatMap(_._2) }.collect()
+        val data = shuffled.mapPartitions { itr =>
+          itr.flatMap(_._2)
+        }.cache().collect()
         val count = data.size
         assert(count === 1e6.toInt)
         assert(data.toSet === (1 to 1e6.toInt).toSet)
 
         assert(stageFailureCount.getOrElse(1, 0) === 0)
-        assert(stageFailureCount.getOrElse(2, 0) == 1)
-        assert(stageFailureCount.getOrElse(3, 0) == 0)
+        assert(stageFailureCount.getOrElse(2, 0) === 1)
+        assert(stageSubmissionCount.getOrElse(1, 0) <= 2)
+        assert(stageSubmissionCount.getOrElse(2, 0) === 2)
       } finally {
         clusterSc.stop()
       }
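
The reworked test exercises exactly this scenario: it injects FetchFailedExceptions only in the first attempt of the shuffle stage (one immediately from partition 0, one from partition 1 after a two-second delay so it arrives while the retry is already under way), while the second attempt just sleeps so it cannot finish before the delayed failure lands. A SparkListener then counts how often each stage was submitted; with the fix in place, stage 2 must be submitted exactly twice (original attempt plus one retry) and stage 1 at most twice. The counting pattern can be lifted out of the test harness; a reduced sketch, where the local-mode SparkContext, the sample job, and the StageSubmissionCounter name are illustrative and not part of this commit:

import scala.collection.mutable

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

// Reduced sketch of the test's instrumentation: count how many times each stage
// is submitted and flag unexpected resubmissions.
object StageSubmissionCounter {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("stage-submission-count"))
    val submissions = mutable.HashMap[Int, Int]()

    sc.addSparkListener(new SparkListener {
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
        val stage = stageSubmitted.stageInfo.stageId
        submissions(stage) = submissions.getOrElse(stage, 0) + 1
      }
    })

    try {
      // A simple shuffle job: one map stage plus one result stage.
      sc.parallelize(1 to 1000, 4).map(x => (x % 10, x)).groupByKey().count()
    } finally {
      sc.stop()
    }

    // With no injected fetch failures every stage should run exactly once; the test
    // above tolerates one retry (<= 2) after its simulated failures.
    assert(submissions.values.forall(_ == 1), s"unexpected resubmissions: $submissions")
  }
}

Listener events are delivered asynchronously, but stopping the SparkContext waits for queued listener events to be processed, so the counts should be complete by the time the assertion runs.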
