@@ -85,12 +85,9 @@ class DAGScheduler(
   private val nextStageId = new AtomicInteger(0)

   private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
-  private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
   private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
   private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
   private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
-  private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
-  private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]

   // Stages we need to run whose parents aren't done
   private[scheduler] val waitingStages = new HashSet[Stage]
@@ -101,9 +98,6 @@ class DAGScheduler(
   // Stages that must be resubmitted due to fetch failures
   private[scheduler] val failedStages = new HashSet[Stage]

-  // Missing tasks from each stage
-  private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
-
   private[scheduler] val activeJobs = new HashSet[ActiveJob]

   // Contains the locations that each RDD's partitions are cached on
@@ -223,7 +217,6 @@ class DAGScheduler(
       new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
     updateJobIdStageIdMaps(jobId, stage)
-    stageToInfos(stage) = StageInfo.fromStage(stage)
     stage
   }

@@ -315,13 +308,12 @@ class DAGScheduler(
    */
   private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
     def updateJobIdStageIdMapsList(stages: List[Stage]) {
-      if (!stages.isEmpty) {
+      if (stages.nonEmpty) {
         val s = stages.head
-        stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
+        s.jobIds += jobId
         jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
-        val parents = getParentStages(s.rdd, jobId)
-        val parentsWithoutThisJobId = parents.filter(p =>
-          !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+        val parents: List[Stage] = getParentStages(s.rdd, jobId)
+        val parentsWithoutThisJobId = parents.filter { !_.jobIds.contains(jobId) }
         updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
       }
     }
@@ -333,16 +325,15 @@ class DAGScheduler(
    * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
    *
    * @param job The job whose state to cleanup.
-   * @param resultStage Specifies the result stage for the job; if set to None, this method
-   *                    searches resultStagesToJob to find and cleanup the appropriate result stage.
    */
-  private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
+  private def cleanupStateForJobAndIndependentStages(job: ActiveJob) {
     val registeredStages = jobIdToStageIds.get(job.jobId)
     if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
       logError("No stages registered for job " + job.jobId)
     } else {
-      stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
-        case (stageId, jobSet) =>
+      stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
+        case (stageId, stage) =>
+          val jobSet = stage.jobIds
           if (!jobSet.contains(job.jobId)) {
             logError(
               "Job %d not registered for stage %d even though that stage was registered for the job"
@@ -355,14 +346,9 @@ class DAGScheduler(
                   logDebug("Removing running stage %d".format(stageId))
                   runningStages -= stage
                 }
-                stageToInfos -= stage
                 for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
                   shuffleToMapStage.remove(k)
                 }
-                if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
-                  logDebug("Removing pending status for stage %d".format(stageId))
-                }
-                pendingTasks -= stage
                 if (waitingStages.contains(stage)) {
                   logDebug("Removing stage %d from waiting set.".format(stageId))
                   waitingStages -= stage
@@ -374,7 +360,6 @@ class DAGScheduler(
               }
               // data structures based on StageId
               stageIdToStage -= stageId
-              stageIdToJobIds -= stageId

               ShuffleMapTask.removeStage(stageId)
               ResultTask.removeStage(stageId)
@@ -393,19 +378,7 @@ class DAGScheduler(
     jobIdToStageIds -= job.jobId
     jobIdToActiveJob -= job.jobId
     activeJobs -= job
-
-    if (resultStage.isEmpty) {
-      // Clean up result stages.
-      val resultStagesForJob = resultStageToJob.keySet.filter(
-        stage => resultStageToJob(stage).jobId == job.jobId)
-      if (resultStagesForJob.size != 1) {
-        logWarning(
-          s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
-      }
-      resultStageToJob --= resultStagesForJob
-    } else {
-      resultStageToJob -= resultStage.get
-    }
+    job.finalStage.resultOfJob = None
   }

   /**
@@ -591,9 +564,10 @@ class DAGScheduler(
       job.listener.jobFailed(exception)
     } finally {
       val s = job.finalStage
-      stageIdToJobIds -= s.id  // clean up data structures that were populated for a local job,
-      stageIdToStage -= s.id   // but that won't get cleaned up via the normal paths through
-      stageToInfos -= s        // completion events or stage abort
+      // clean up data structures that were populated for a local job,
+      // but that won't get cleaned up via the normal paths through
+      // completion events or stage abort
+      stageIdToStage -= s.id
       jobIdToStageIds -= job.jobId
       listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
     }
@@ -605,12 +579,8 @@ class DAGScheduler(
   // That should take care of at least part of the priority inversion problem with
   // cross-job dependencies.
   private def activeJobForStage(stage: Stage): Option[Int] = {
-    if (stageIdToJobIds.contains(stage.id)) {
-      val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
-      jobsThatUseStage.find(jobIdToActiveJob.contains)
-    } else {
-      None
-    }
+    val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted
+    jobsThatUseStage.find(jobIdToActiveJob.contains)
   }

   private[scheduler] def handleJobGroupCancelled(groupId: String) {
@@ -642,9 +612,8 @@ class DAGScheduler(
       // is in the process of getting stopped.
       val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
       runningStages.foreach { stage =>
-        val info = stageToInfos(stage)
-        info.stageFailed(stageFailedMessage)
-        listenerBus.post(SparkListenerStageCompleted(info))
+        stage.info.stageFailed(stageFailedMessage)
+        listenerBus.post(SparkListenerStageCompleted(stage.info))
       }
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
@@ -690,7 +659,7 @@ class DAGScheduler(
     } else {
       jobIdToActiveJob(jobId) = job
       activeJobs += job
-      resultStageToJob(finalStage) = job
+      finalStage.resultOfJob = Some(job)
       listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
         properties))
       submitStage(finalStage)
@@ -727,8 +696,7 @@ class DAGScheduler(
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
-    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
-    myPending.clear()
+    stage.pendingTasks.clear()
     var tasks = ArrayBuffer[Task[_]]()
     if (stage.isShuffleMap) {
       for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -737,7 +705,7 @@ class DAGScheduler(
       }
     } else {
       // This is a final stage; figure out its job's missing partitions
-      val job = resultStageToJob(stage)
+      val job = stage.resultOfJob.get
       for (id <- 0 until job.numPartitions if !job.finished(id)) {
         val partition = job.partitions(id)
         val locs = getPreferredLocs(stage.rdd, partition)
@@ -758,7 +726,7 @@ class DAGScheduler(
       // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
       // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
       // event.
-      listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
+      listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
@@ -778,11 +746,11 @@ class DAGScheduler(
       }

       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      myPending ++= tasks
-      logDebug("New pending tasks: " + myPending)
+      stage.pendingTasks ++= tasks
+      logDebug("New pending tasks: " + stage.pendingTasks)
       taskScheduler.submitTasks(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
-      stageToInfos(stage).submissionTime = Some(clock.getTime())
+      stage.info.submissionTime = Some(clock.getTime())
     } else {
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -807,13 +775,13 @@ class DAGScheduler(
     val stage = stageIdToStage(task.stageId)

     def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stageToInfos(stage).submissionTime match {
+      val serviceTime = stage.info.submissionTime match {
         case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
         case _ => "Unknown"
       }
       logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stageToInfos(stage).completionTime = Some(clock.getTime())
-      listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
+      stage.info.completionTime = Some(clock.getTime())
+      listenerBus.post(SparkListenerStageCompleted(stage.info))
       runningStages -= stage
     }
     event.reason match {
@@ -822,18 +790,18 @@ class DAGScheduler(
           // TODO: fail the stage if the accumulator update fails...
          Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
        }
-        pendingTasks(stage) -= task
+        stage.pendingTasks -= task
        task match {
          case rt: ResultTask[_, _] =>
-            resultStageToJob.get(stage) match {
+            stage.resultOfJob match {
              case Some(job) =>
                if (!job.finished(rt.outputId)) {
                  job.finished(rt.outputId) = true
                  job.numFinished += 1
                  // If the whole job has finished, remove it
                  if (job.numFinished == job.numPartitions) {
                    markStageAsFinished(stage)
-                    cleanupStateForJobAndIndependentStages(job, Some(stage))
+                    cleanupStateForJobAndIndependentStages(job)
                    listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
                  }

@@ -860,7 +828,7 @@ class DAGScheduler(
             } else {
               stage.addOutputLoc(smt.partitionId, status)
             }
-            if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
+            if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
               markStageAsFinished(stage)
               logInfo("looking for newly runnable stages")
               logInfo("running: " + runningStages)
909877
910878 case Resubmitted =>
911879 logInfo(" Resubmitted " + task + " , so marking it as still running" )
912- pendingTasks( stage) += task
880+ stage.pendingTasks += task
913881
914882 case FetchFailed (bmAddress, shuffleId, mapId, reduceId) =>
915883 // Mark the stage that the reducer was in as unrunnable
@@ -994,13 +962,14 @@ class DAGScheduler(
   }

   private[scheduler] def handleStageCancellation(stageId: Int) {
-    if (stageIdToJobIds.contains(stageId)) {
-      val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray
-      jobsThatUseStage.foreach(jobId => {
-        handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId))
-      })
-    } else {
-      logInfo("No active jobs to kill for Stage " + stageId)
+    stageIdToStage.get(stageId) match {
+      case Some(stage) =>
+        val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
+        jobsThatUseStage.foreach { jobId =>
+          handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
+        }
+      case None =>
+        logInfo("No active jobs to kill for Stage " + stageId)
     }
     submitWaitingStages()
   }
@@ -1009,8 +978,8 @@ class DAGScheduler(
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to cancel unregistered job " + jobId)
     } else {
-      failJobAndIndependentStages(jobIdToActiveJob(jobId),
-        "Job %d cancelled %s".format(jobId, reason), None)
+      failJobAndIndependentStages(
+        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
     }
     submitWaitingStages()
   }
@@ -1024,26 +993,21 @@ class DAGScheduler(
       // Skip all the actions if the stage has been removed.
       return
     }
-    val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
-    stageToInfos(failedStage).completionTime = Some(clock.getTime())
-    for (resultStage <- dependentStages) {
-      val job = resultStageToJob(resultStage)
-      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
-        Some(resultStage))
+    val dependentJobs: Seq[ActiveJob] =
+      activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
+    failedStage.info.completionTime = Some(clock.getTime())
+    for (job <- dependentJobs) {
+      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
     }
-    if (dependentStages.isEmpty) {
+    if (dependentJobs.isEmpty) {
       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
     }
   }

   /**
    * Fails a job and all stages that are only used by that job, and cleans up relevant state.
-   *
-   * @param resultStage The result stage for the job, if known. Used to cleanup state for the job
-   *                    slightly more efficiently than when not specified.
    */
-  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
-      resultStage: Option[Stage]) {
+  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
     val error = new SparkException(failureReason)
     var ableToCancelStages = true
@@ -1057,7 +1021,7 @@ class DAGScheduler(
       logError("No stages registered for job " + job.jobId)
     }
     stages.foreach { stageId =>
-      val jobsForStage = stageIdToJobIds.get(stageId)
+      val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds)
       if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
         logError(
           "Job %d not registered for stage %d even though that stage was registered for the job"
@@ -1071,9 +1035,8 @@ class DAGScheduler(
         if (runningStages.contains(stage)) {
           try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
             taskScheduler.cancelTasks(stageId, shouldInterruptThread)
-            val stageInfo = stageToInfos(stage)
-            stageInfo.stageFailed(failureReason)
-            listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
+            stage.info.stageFailed(failureReason)
+            listenerBus.post(SparkListenerStageCompleted(stage.info))
           } catch {
             case e: UnsupportedOperationException =>
               logInfo(s"Could not cancel tasks for stage $stageId", e)
@@ -1086,7 +1049,7 @@ class DAGScheduler(

     if (ableToCancelStages) {
       job.listener.jobFailed(error)
-      cleanupStateForJobAndIndependentStages(job, resultStage)
+      cleanupStateForJobAndIndependentStages(job)
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
   }
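
Taken together, these hunks delete four scheduler-level maps (stageIdToJobIds, resultStageToJob, stageToInfos, pendingTasks) and read the same state off the Stage object instead (stage.jobIds, stage.resultOfJob, stage.info, stage.pendingTasks). The counterpart change to Stage.scala is not included in the hunks shown above; the fragment below is a minimal sketch, inferred purely from the call sites in this diff, of the fields DAGScheduler now expects Stage to carry. The constructor parameters are elided and the exact declarations are assumptions, not the committed code.

// Sketch only: per-stage state assumed by the DAGScheduler changes above.
// Field names come from the call sites in this diff; modifiers and
// initialization details are illustrative.
import scala.collection.mutable.HashSet

private[spark] class Stage(/* id, rdd, numTasks, shuffleDep, parents, jobId, callSite */) {

  /** Set of jobs this stage belongs to; replaces the stageIdToJobIds map. */
  val jobIds = new HashSet[Int]

  /** Tasks submitted for this stage but not yet finished; replaces the
   *  scheduler-level pendingTasks map. */
  val pendingTasks = new HashSet[Task[_]]

  /** For a result stage, the ActiveJob it computes (None for shuffle map
   *  stages); replaces the resultStageToJob map. */
  var resultOfJob: Option[ActiveJob] = None

  /** Listener-facing metadata for this stage; replaces the stageToInfos map. */
  val info: StageInfo = StageInfo.fromStage(this)
}

Keeping this state on Stage ties its lifetime to the stage itself, so each removed map is one less structure the cleanup paths (cleanupStateForJobAndIndependentStages, the local-job finally block) can forget to purge.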