@@ -26,35 +26,33 @@ import org.apache.spark._
 
 class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
 
-  // TODO we should run this with a matrix of configurations: different shufflers,
-  // external shuffle service, etc.  But that is really pushing the question of how to run
-  // such a long test ...
-
-  ignore("no concurrent retries for stage attempts (SPARK-7308)") {
-    // see SPARK-7308 for a detailed description of the conditions this is trying to recreate.
-    // note that this is somewhat convoluted for a test case, but isn't actually very unusual
-    // under a real workload.  We only fail the first attempt of stage 2, but that
-    // could be enough to cause havoc.
-
-    (0 until 100).foreach { idx =>
-      println(new Date() + "\ttrial " + idx)
+  test("no concurrent retries for stage attempts (SPARK-8103)") {
+    // make sure that if we get fetch failures after the retry has started, we ignore them,
+    // and so don't end up submitting multiple concurrent attempts for the same stage
+
+    (0 until 20).foreach { idx =>
       logInfo(new Date() + "\ttrial " + idx)
 
       val conf = new SparkConf().set("spark.executor.memory", "100m")
-      val clusterSc = new SparkContext("local-cluster[5,4,100]", "test-cluster", conf)
+      val clusterSc = new SparkContext("local-cluster[2,2,100]", "test-cluster", conf)
       val bms = ArrayBuffer[BlockManagerId]()
       val stageFailureCount = HashMap[Int, Int]()
+      val stageSubmissionCount = HashMap[Int, Int]()
       clusterSc.addSparkListener(new SparkListener {
         override def onBlockManagerAdded(bmAdded: SparkListenerBlockManagerAdded): Unit = {
           bms += bmAdded.blockManagerId
         }
 
+        override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+          val stage = stageSubmitted.stageInfo.stageId
+          stageSubmissionCount(stage) = stageSubmissionCount.getOrElse(stage, 0) + 1
+        }
+
+
         override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
           if (stageCompleted.stageInfo.failureReason.isDefined) {
             val stage = stageCompleted.stageInfo.stageId
             stageFailureCount(stage) = stageFailureCount.getOrElse(stage, 0) + 1
-            val reason = stageCompleted.stageInfo.failureReason.get
-            println("stage " + stage + " failed: " + stageFailureCount(stage))
           }
         }
       })
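
The SparkListener bookkeeping added above is the entire verification mechanism of the new test: stage submissions and stage failures are tallied per stage id, and the assertions at the end of the test check those tallies. As a minimal, self-contained sketch of the same counting pattern on a healthy job (a standalone illustration under assumed Spark 1.x APIs, not part of this patch; the object name and the demo job are invented):

import scala.collection.mutable.HashMap

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

// Count how often each stage is submitted, then assert on the tallies.
object StageSubmissionTally {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("stage-submission-tally"))
    val submissions = HashMap[Int, Int]()
    sc.addSparkListener(new SparkListener {
      override def onStageSubmitted(submitted: SparkListenerStageSubmitted): Unit = {
        // listener events are delivered on a separate thread, so guard the map
        submissions.synchronized {
          val stage = submitted.stageInfo.stageId
          submissions(stage) = submissions.getOrElse(stage, 0) + 1
        }
      }
    })
    // one shuffle => two stages; with no failures each is submitted exactly once
    sc.parallelize(1 to 1000, 4).map(x => (x % 10, x)).groupByKey(4).count()
    sc.stop() // stop() drains the listener bus before returning
    assert(submissions.values.forall(_ == 1), s"unexpected resubmission: $submissions")
  }
}
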
@@ -66,34 +64,37 @@ class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
         // to avoid broadcast failures
         val someBlockManager = bms.filter{!_.isDriver}(0)
 
-        val shuffled = rawData.groupByKey(100).mapPartitionsWithIndex { case (idx, itr) =>
+        val shuffled = rawData.groupByKey(20).mapPartitionsWithIndex { case (idx, itr) =>
           // we want one failure quickly, and more failures after stage 0 has finished its
           // second attempt
           val stageAttemptId = TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId
           if (stageAttemptId == 0) {
             if (idx == 0) {
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else if (idx > 0 && math.random < 0.2) {
-              Thread.sleep(5000)
+            } else if (idx == 1) {
+              Thread.sleep(2000)
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else {
-              // want to make sure plenty of these finish after task 0 fails, and some even finish
-              // after the previous stage is retried and this stage retry is started
-              Thread.sleep((500 + math.random * 5000).toLong)
             }
+          } else {
+            // just to make sure the second attempt doesn't finish before we trigger more failures
+            // from the first attempt
+            Thread.sleep(2000)
           }
           itr.map { x => ((x._1 + 5) % 100) -> x._2 }
         }
-        val data = shuffled.mapPartitions { itr => itr.flatMap(_._2) }.collect()
+        val data = shuffled.mapPartitions { itr =>
+          itr.flatMap(_._2)
+        }.cache().collect()
         val count = data.size
         assert(count === 1e6.toInt)
         assert(data.toSet === (1 to 1e6.toInt).toSet)
 
         assert(stageFailureCount.getOrElse(1, 0) === 0)
-        assert(stageFailureCount.getOrElse(2, 0) == 1)
-        assert(stageFailureCount.getOrElse(3, 0) == 0)
+        assert(stageFailureCount.getOrElse(2, 0) === 1)
+        assert(stageSubmissionCount.getOrElse(1, 0) <= 2)
+        assert(stageSubmissionCount.getOrElse(2, 0) === 2)
       } finally {
         clusterSc.stop()
       }
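
The failure-injection idiom in the hunk above is worth reading in isolation: a task asks its TaskContext for the stage attempt id and throws a simulated FetchFailedException only on attempt 0, so the scheduler marks the map output lost, re-runs the map stage, and the retried reduce stage succeeds. A sketch of that shape, under the same assumptions as the test (TaskContextImpl is private[spark], so this compiles only inside Spark's own source tree; sc and a BlockManagerId named someBlockManager are assumed to be in scope):

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.shuffle.FetchFailedException

// Fail partition 0 of the shuffle-reading stage once, then succeed on the retry.
val result = sc.parallelize(1 to 1000, 10)
  .map(x => (x % 100, x))
  .groupByKey(10)
  .mapPartitionsWithIndex { case (idx, itr) =>
    val attempt = TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId
    if (attempt == 0 && idx == 0) {
      // report a (simulated) failure fetching map output for shuffle 0
      throw new FetchFailedException(someBlockManager, 0, 0, idx,
        cause = new RuntimeException("simulated fetch failure"))
    }
    itr
  }
  .count()
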