Skip to content

Commit 9d8666c

Browse files
committed
Part of [SPARK-2456] Removed some HashMaps from DAGScheduler by storing information in Stage.
This is part of the scheduler cleanup/refactoring effort to make the scheduler code easier to maintain. @kayousterhout @markhamstra please take a look ... Author: Reynold Xin <rxin@apache.org> Closes #1561 from rxin/dagSchedulerHashMaps and squashes the following commits: 1c44e15 [Reynold Xin] Clear pending tasks in submitMissingTasks. 620a0d1 [Reynold Xin] Use filterKeys. 5b54404 [Reynold Xin] Code review feedback. c1e9a1c [Reynold Xin] Removed some HashMaps from DAGScheduler by storing information in Stage.
1 parent afd757a commit 9d8666c

File tree

3 files changed

+69
-97
lines changed

3 files changed

+69
-97
lines changed

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 53 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,9 @@ class DAGScheduler(
8585
private val nextStageId = new AtomicInteger(0)
8686

8787
private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
88-
private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
8988
private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
9089
private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
9190
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
92-
private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
93-
private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]
9491

9592
// Stages we need to run whose parents aren't done
9693
private[scheduler] val waitingStages = new HashSet[Stage]
@@ -101,9 +98,6 @@ class DAGScheduler(
10198
// Stages that must be resubmitted due to fetch failures
10299
private[scheduler] val failedStages = new HashSet[Stage]
103100

104-
// Missing tasks from each stage
105-
private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
106-
107101
private[scheduler] val activeJobs = new HashSet[ActiveJob]
108102

109103
// Contains the locations that each RDD's partitions are cached on
@@ -223,7 +217,6 @@ class DAGScheduler(
223217
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
224218
stageIdToStage(id) = stage
225219
updateJobIdStageIdMaps(jobId, stage)
226-
stageToInfos(stage) = StageInfo.fromStage(stage)
227220
stage
228221
}
229222

@@ -315,13 +308,12 @@ class DAGScheduler(
315308
*/
316309
private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
317310
def updateJobIdStageIdMapsList(stages: List[Stage]) {
318-
if (!stages.isEmpty) {
311+
if (stages.nonEmpty) {
319312
val s = stages.head
320-
stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
313+
s.jobIds += jobId
321314
jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
322-
val parents = getParentStages(s.rdd, jobId)
323-
val parentsWithoutThisJobId = parents.filter(p =>
324-
!stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
315+
val parents: List[Stage] = getParentStages(s.rdd, jobId)
316+
val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
325317
updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
326318
}
327319
}
@@ -333,16 +325,15 @@ class DAGScheduler(
333325
* handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
334326
*
335327
* @param job The job whose state to cleanup.
336-
* @param resultStage Specifies the result stage for the job; if set to None, this method
337-
* searches resultStagesToJob to find and cleanup the appropriate result stage.
338328
*/
339-
private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
329+
private def cleanupStateForJobAndIndependentStages(job: ActiveJob) {
340330
val registeredStages = jobIdToStageIds.get(job.jobId)
341331
if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
342332
logError("No stages registered for job " + job.jobId)
343333
} else {
344-
stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
345-
case (stageId, jobSet) =>
334+
stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
335+
case (stageId, stage) =>
336+
val jobSet = stage.jobIds
346337
if (!jobSet.contains(job.jobId)) {
347338
logError(
348339
"Job %d not registered for stage %d even though that stage was registered for the job"
@@ -355,14 +346,9 @@ class DAGScheduler(
355346
logDebug("Removing running stage %d".format(stageId))
356347
runningStages -= stage
357348
}
358-
stageToInfos -= stage
359349
for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
360350
shuffleToMapStage.remove(k)
361351
}
362-
if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
363-
logDebug("Removing pending status for stage %d".format(stageId))
364-
}
365-
pendingTasks -= stage
366352
if (waitingStages.contains(stage)) {
367353
logDebug("Removing stage %d from waiting set.".format(stageId))
368354
waitingStages -= stage
@@ -374,7 +360,6 @@ class DAGScheduler(
374360
}
375361
// data structures based on StageId
376362
stageIdToStage -= stageId
377-
stageIdToJobIds -= stageId
378363

379364
ShuffleMapTask.removeStage(stageId)
380365
ResultTask.removeStage(stageId)
@@ -393,19 +378,7 @@ class DAGScheduler(
393378
jobIdToStageIds -= job.jobId
394379
jobIdToActiveJob -= job.jobId
395380
activeJobs -= job
396-
397-
if (resultStage.isEmpty) {
398-
// Clean up result stages.
399-
val resultStagesForJob = resultStageToJob.keySet.filter(
400-
stage => resultStageToJob(stage).jobId == job.jobId)
401-
if (resultStagesForJob.size != 1) {
402-
logWarning(
403-
s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
404-
}
405-
resultStageToJob --= resultStagesForJob
406-
} else {
407-
resultStageToJob -= resultStage.get
408-
}
381+
job.finalStage.resultOfJob = None
409382
}
410383

411384
/**
@@ -591,9 +564,10 @@ class DAGScheduler(
591564
job.listener.jobFailed(exception)
592565
} finally {
593566
val s = job.finalStage
594-
stageIdToJobIds -= s.id // clean up data structures that were populated for a local job,
595-
stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through
596-
stageToInfos -= s // completion events or stage abort
567+
// clean up data structures that were populated for a local job,
568+
// but that won't get cleaned up via the normal paths through
569+
// completion events or stage abort
570+
stageIdToStage -= s.id
597571
jobIdToStageIds -= job.jobId
598572
listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
599573
}
@@ -605,12 +579,8 @@ class DAGScheduler(
605579
// That should take care of at least part of the priority inversion problem with
606580
// cross-job dependencies.
607581
private def activeJobForStage(stage: Stage): Option[Int] = {
608-
if (stageIdToJobIds.contains(stage.id)) {
609-
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
610-
jobsThatUseStage.find(jobIdToActiveJob.contains)
611-
} else {
612-
None
613-
}
582+
val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted
583+
jobsThatUseStage.find(jobIdToActiveJob.contains)
614584
}
615585

616586
private[scheduler] def handleJobGroupCancelled(groupId: String) {
@@ -642,9 +612,8 @@ class DAGScheduler(
642612
// is in the process of getting stopped.
643613
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
644614
runningStages.foreach { stage =>
645-
val info = stageToInfos(stage)
646-
info.stageFailed(stageFailedMessage)
647-
listenerBus.post(SparkListenerStageCompleted(info))
615+
stage.info.stageFailed(stageFailedMessage)
616+
listenerBus.post(SparkListenerStageCompleted(stage.info))
648617
}
649618
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
650619
}
@@ -690,7 +659,7 @@ class DAGScheduler(
690659
} else {
691660
jobIdToActiveJob(jobId) = job
692661
activeJobs += job
693-
resultStageToJob(finalStage) = job
662+
finalStage.resultOfJob = Some(job)
694663
listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
695664
properties))
696665
submitStage(finalStage)
@@ -727,8 +696,7 @@ class DAGScheduler(
727696
private def submitMissingTasks(stage: Stage, jobId: Int) {
728697
logDebug("submitMissingTasks(" + stage + ")")
729698
// Get our pending tasks and remember them in our pendingTasks entry
730-
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
731-
myPending.clear()
699+
stage.pendingTasks.clear()
732700
var tasks = ArrayBuffer[Task[_]]()
733701
if (stage.isShuffleMap) {
734702
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -737,7 +705,7 @@ class DAGScheduler(
737705
}
738706
} else {
739707
// This is a final stage; figure out its job's missing partitions
740-
val job = resultStageToJob(stage)
708+
val job = stage.resultOfJob.get
741709
for (id <- 0 until job.numPartitions if !job.finished(id)) {
742710
val partition = job.partitions(id)
743711
val locs = getPreferredLocs(stage.rdd, partition)
@@ -758,7 +726,7 @@ class DAGScheduler(
758726
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
759727
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
760728
// event.
761-
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
729+
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
762730

763731
// Preemptively serialize a task to make sure it can be serialized. We are catching this
764732
// exception here because it would be fairly hard to catch the non-serializable exception
@@ -778,11 +746,11 @@ class DAGScheduler(
778746
}
779747

780748
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
781-
myPending ++= tasks
782-
logDebug("New pending tasks: " + myPending)
749+
stage.pendingTasks ++= tasks
750+
logDebug("New pending tasks: " + stage.pendingTasks)
783751
taskScheduler.submitTasks(
784752
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
785-
stageToInfos(stage).submissionTime = Some(clock.getTime())
753+
stage.info.submissionTime = Some(clock.getTime())
786754
} else {
787755
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
788756
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -807,13 +775,13 @@ class DAGScheduler(
807775
val stage = stageIdToStage(task.stageId)
808776

809777
def markStageAsFinished(stage: Stage) = {
810-
val serviceTime = stageToInfos(stage).submissionTime match {
778+
val serviceTime = stage.info.submissionTime match {
811779
case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
812780
case _ => "Unknown"
813781
}
814782
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
815-
stageToInfos(stage).completionTime = Some(clock.getTime())
816-
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
783+
stage.info.completionTime = Some(clock.getTime())
784+
listenerBus.post(SparkListenerStageCompleted(stage.info))
817785
runningStages -= stage
818786
}
819787
event.reason match {
@@ -822,18 +790,18 @@ class DAGScheduler(
822790
// TODO: fail the stage if the accumulator update fails...
823791
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
824792
}
825-
pendingTasks(stage) -= task
793+
stage.pendingTasks -= task
826794
task match {
827795
case rt: ResultTask[_, _] =>
828-
resultStageToJob.get(stage) match {
796+
stage.resultOfJob match {
829797
case Some(job) =>
830798
if (!job.finished(rt.outputId)) {
831799
job.finished(rt.outputId) = true
832800
job.numFinished += 1
833801
// If the whole job has finished, remove it
834802
if (job.numFinished == job.numPartitions) {
835803
markStageAsFinished(stage)
836-
cleanupStateForJobAndIndependentStages(job, Some(stage))
804+
cleanupStateForJobAndIndependentStages(job)
837805
listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
838806
}
839807

@@ -860,7 +828,7 @@ class DAGScheduler(
860828
} else {
861829
stage.addOutputLoc(smt.partitionId, status)
862830
}
863-
if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
831+
if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
864832
markStageAsFinished(stage)
865833
logInfo("looking for newly runnable stages")
866834
logInfo("running: " + runningStages)
@@ -909,7 +877,7 @@ class DAGScheduler(
909877

910878
case Resubmitted =>
911879
logInfo("Resubmitted " + task + ", so marking it as still running")
912-
pendingTasks(stage) += task
880+
stage.pendingTasks += task
913881

914882
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
915883
// Mark the stage that the reducer was in as unrunnable
@@ -994,13 +962,14 @@ class DAGScheduler(
994962
}
995963

996964
private[scheduler] def handleStageCancellation(stageId: Int) {
997-
if (stageIdToJobIds.contains(stageId)) {
998-
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray
999-
jobsThatUseStage.foreach(jobId => {
1000-
handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId))
1001-
})
1002-
} else {
1003-
logInfo("No active jobs to kill for Stage " + stageId)
965+
stageIdToStage.get(stageId) match {
966+
case Some(stage) =>
967+
val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
968+
jobsThatUseStage.foreach { jobId =>
969+
handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
970+
}
971+
case None =>
972+
logInfo("No active jobs to kill for Stage " + stageId)
1004973
}
1005974
submitWaitingStages()
1006975
}
@@ -1009,8 +978,8 @@ class DAGScheduler(
1009978
if (!jobIdToStageIds.contains(jobId)) {
1010979
logDebug("Trying to cancel unregistered job " + jobId)
1011980
} else {
1012-
failJobAndIndependentStages(jobIdToActiveJob(jobId),
1013-
"Job %d cancelled %s".format(jobId, reason), None)
981+
failJobAndIndependentStages(
982+
jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
1014983
}
1015984
submitWaitingStages()
1016985
}
@@ -1024,26 +993,21 @@ class DAGScheduler(
1024993
// Skip all the actions if the stage has been removed.
1025994
return
1026995
}
1027-
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
1028-
stageToInfos(failedStage).completionTime = Some(clock.getTime())
1029-
for (resultStage <- dependentStages) {
1030-
val job = resultStageToJob(resultStage)
1031-
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
1032-
Some(resultStage))
996+
val dependentJobs: Seq[ActiveJob] =
997+
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
998+
failedStage.info.completionTime = Some(clock.getTime())
999+
for (job <- dependentJobs) {
1000+
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
10331001
}
1034-
if (dependentStages.isEmpty) {
1002+
if (dependentJobs.isEmpty) {
10351003
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
10361004
}
10371005
}
10381006

10391007
/**
10401008
* Fails a job and all stages that are only used by that job, and cleans up relevant state.
1041-
*
1042-
* @param resultStage The result stage for the job, if known. Used to cleanup state for the job
1043-
* slightly more efficiently than when not specified.
10441009
*/
1045-
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
1046-
resultStage: Option[Stage]) {
1010+
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
10471011
val error = new SparkException(failureReason)
10481012
var ableToCancelStages = true
10491013

@@ -1057,7 +1021,7 @@ class DAGScheduler(
10571021
logError("No stages registered for job " + job.jobId)
10581022
}
10591023
stages.foreach { stageId =>
1060-
val jobsForStage = stageIdToJobIds.get(stageId)
1024+
val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds)
10611025
if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
10621026
logError(
10631027
"Job %d not registered for stage %d even though that stage was registered for the job"
@@ -1071,9 +1035,8 @@ class DAGScheduler(
10711035
if (runningStages.contains(stage)) {
10721036
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
10731037
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
1074-
val stageInfo = stageToInfos(stage)
1075-
stageInfo.stageFailed(failureReason)
1076-
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
1038+
stage.info.stageFailed(failureReason)
1039+
listenerBus.post(SparkListenerStageCompleted(stage.info))
10771040
} catch {
10781041
case e: UnsupportedOperationException =>
10791042
logInfo(s"Could not cancel tasks for stage $stageId", e)
@@ -1086,7 +1049,7 @@ class DAGScheduler(
10861049

10871050
if (ableToCancelStages) {
10881051
job.listener.jobFailed(error)
1089-
cleanupStateForJobAndIndependentStages(job, resultStage)
1052+
cleanupStateForJobAndIndependentStages(job)
10901053
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
10911054
}
10921055
}

0 commit comments

Comments
 (0)