@@ -101,9 +101,15 @@ class DAGSchedulerSuite
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
+    val submittedStageInfos = new HashSet[StageInfo]
     val successfulStages = new HashSet[Int]
     val failedStages = new ArrayBuffer[Int]
     val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+      submittedStageInfos += stageSubmitted.stageInfo
+    }
+
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
       stageByOrderOfExecution += stageInfo.stageId
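The new `onStageSubmitted` hook records every submitted `StageInfo`. Assuming each submission of a stage creates a fresh `StageInfo` object (so the `HashSet` keeps one entry per attempt), counting entries with a given stage id yields the number of distinct submissions. A minimal sketch of that counting idiom, using the `sparkListener` defined above (the new tests below define equivalent helpers inline):

```scala
// Counts how many distinct attempts of the given stage the listener has seen.
def countSubmittedStageAttempts(stageId: Int): Int =
  sparkListener.submittedStageInfos.count(_.stageId == stageId)
```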
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
     // Enable local execution for this test
     val conf = new SparkConf().set("spark.localExecution.enabled", "true")
     sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+    sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
     sparkListener.failedStages.clear()
     failure = null
@@ -547,6 +554,133 @@ class DAGSchedulerSuite
     assert(sparkListener.failedStages.size == 1)
   }
 
+  /** This tests the case where another FetchFailed comes in while the map stage is getting
+    * re-run. */
+  test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    val mapStageId = 0
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+    // The MapOutputTracker should know about both map output locations.
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+      Array("hostA", "hostB"))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(sparkListener.failedStages.contains(1))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    runEvent(ResubmitFailedStages)
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // NOTE: the actual ResubmitFailedStages event may get fired at any point during this; that
+    // shouldn't affect anything -- firing it here explicitly just makes sure it happens between
+    // the desired event and our check.
+  }
+
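The first test hinges on the scheduler declining to start a second concurrent attempt of a map stage that is already being re-run. As a rough, hypothetical sketch of that guard (not the actual `DAGScheduler` logic, which is more involved), resubmission can be thought of as skipping stages that are still running:

```scala
import scala.collection.mutable

// Hypothetical sketch of the resubmission guard: a failed stage only gets a
// new attempt if no attempt of it is currently running.
class ResubmitGuard {
  val runningStages = mutable.HashSet[Int]()
  val failedStages = mutable.HashSet[Int]()

  def resubmitFailedStages(): Unit = {
    val toResubmit = failedStages.filterNot(runningStages.contains)
    failedStages.clear()
    toResubmit.foreach { stageId =>
      runningStages += stageId // submit a fresh attempt of this stage
    }
  }
}
```

Under this model, the second ResubmitFailedStages in the test is a no-op: the map stage is still in the running set, so no third attempt is submitted.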
+  /** This tests the case where a late FetchFailed comes in after the map stage has finished
+    * getting retried and a new reduce stage starts running.
+    */
+  test("extremely late fetch failures don't cause multiple concurrent attempts for the same stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+
+    // The reduce stage should have been submitted.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each of the map and reduce stages.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger another round of stage resubmission.
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored.
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
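The second test additionally depends on the scheduler discarding a FetchFailed that arrives from a stage attempt that has since been superseded. A hypothetical sketch of that bookkeeping, assumed purely for illustration (Spark tracks this through its own per-stage attempt state, not this exact structure):

```scala
import scala.collection.mutable

// Hypothetical sketch: completion events are only honored if they come from
// the stage's current attempt, so a late FetchFailed from an old reduce-stage
// attempt cannot trigger yet another map-stage resubmission.
class AttemptTracker {
  private val currentAttempt = mutable.Map[Int, Int]()

  // Registers a new attempt of the stage and returns its attempt id.
  def startNewAttempt(stageId: Int): Int = {
    val next = currentAttempt.getOrElse(stageId, -1) + 1
    currentAttempt(stageId) = next
    next
  }

  // True only when the event's attempt id matches the stage's live attempt.
  def isCurrent(stageId: Int, attemptId: Int): Boolean =
    currentAttempt.get(stageId).contains(attemptId)
}
```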