Skip to content

Commit a9bf31f

Browse files
committed
wip
1 parent 8f7308f commit a9bf31f

File tree

12 files changed

+46
-17
lines changed

12 files changed

+46
-17
lines changed

core/src/main/scala/org/apache/spark/SparkException.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,13 @@ class SparkException(message: String, cause: Throwable)
3030
*/
3131
private[spark] class SparkDriverExecutionException(cause: Throwable)
3232
extends SparkException("Execution error", cause)
33+
34+
/**
35+
* Exception indicating an error internal to Spark -- it is in an inconsistent state, not due
36+
* to any error by the user
37+
*/
38+
class SparkIllegalStateException(message: String, cause: Throwable)
39+
extends SparkException(message, cause) {
40+
41+
def this(message: String) = this(message, null)
42+
}

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ private[spark] class TaskContextImpl(
3030
override val attemptNumber: Int,
3131
override val taskMemoryManager: TaskMemoryManager,
3232
val runningLocally: Boolean = false,
33+
val stageAttemptId: Int = 0, // for testing
3334
val taskMetrics: TaskMetrics = TaskMetrics.empty)
3435
extends TaskContext
3536
with Logging {

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,6 @@ class DAGScheduler(
834834
// Get our pending tasks and remember them in our pendingTasks entry
835835
stage.pendingTasks.clear()
836836

837-
838837
// First figure out the indexes of partition ids to compute.
839838
val partitionsToCompute: Seq[Int] = {
840839
stage match {
@@ -894,7 +893,7 @@ class DAGScheduler(
894893
partitionsToCompute.map { id =>
895894
val locs = getPreferredLocs(stage.rdd, id)
896895
val part = stage.rdd.partitions(id)
897-
new ShuffleMapTask(stage.id, taskBinary, part, locs)
896+
new ShuffleMapTask(stage.id, stage.attemptId, taskBinary, part, locs)
898897
}
899898

900899
case stage: ResultStage =>
@@ -903,7 +902,7 @@ class DAGScheduler(
903902
val p: Int = job.partitions(id)
904903
val part = stage.rdd.partitions(p)
905904
val locs = getPreferredLocs(stage.rdd, p)
906-
new ResultTask(stage.id, taskBinary, part, locs, id)
905+
new ResultTask(stage.id, stage.attemptId, taskBinary, part, locs, id)
907906
}
908907
}
909908

@@ -977,6 +976,7 @@ class DAGScheduler(
977976
val stageId = task.stageId
978977
val taskType = Utils.getFormattedClassName(task)
979978

979+
// REVIEWERS: does this need special handling for multiple completions of the same task?
980980
outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
981981
event.taskInfo.attempt, event.reason)
982982

@@ -1039,10 +1039,11 @@ class DAGScheduler(
10391039
val execId = status.location.executorId
10401040
logDebug("ShuffleMapTask finished on " + execId)
10411041
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
1042-
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
1042+
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
10431043
} else {
10441044
shuffleStage.addOutputLoc(smt.partitionId, status)
10451045
}
1046+
10461047
if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
10471048
markStageAsFinished(shuffleStage)
10481049
logInfo("looking for newly runnable stages")
@@ -1106,9 +1107,14 @@ class DAGScheduler(
11061107
// multiple tasks running concurrently on different executors). In that case, it is possible
11071108
// the fetch failure has already been handled by the scheduler.
11081109
if (runningStages.contains(failedStage)) {
1109-
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
1110-
s"due to a fetch failure from $mapStage (${mapStage.name})")
1111-
markStageAsFinished(failedStage, Some(failureMessage))
1110+
if (failedStage.attemptId - 1 > task.stageAttemptId) {
1111+
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
1112+
s" ${task.stageAttemptId}, which has already failed")
1113+
} else {
1114+
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
1115+
s"due to a fetch failure from $mapStage (${mapStage.name})")
1116+
markStageAsFinished(failedStage, Some(failureMessage))
1117+
}
11121118
}
11131119

11141120
if (disallowStageRetryForTest) {

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
4141
*/
4242
private[spark] class ResultTask[T, U](
4343
stageId: Int,
44+
stageAttemptId: Int,
4445
taskBinary: Broadcast[Array[Byte]],
4546
partition: Partition,
4647
@transient locs: Seq[TaskLocation],
4748
val outputId: Int)
48-
extends Task[U](stageId, partition.index) with Serializable {
49+
extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
4950

5051
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
5152
if (locs == null) Nil else locs.toSet.toSeq

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
4040
*/
4141
private[spark] class ShuffleMapTask(
4242
stageId: Int,
43+
stageAttemptId: Int,
4344
taskBinary: Broadcast[Array[Byte]],
4445
partition: Partition,
4546
@transient private var locs: Seq[TaskLocation])
46-
extends Task[MapStatus](stageId, partition.index) with Logging {
47+
extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
4748

4849
/** A constructor used only in test suites. This does not require passing in an RDD. */
4950
def this(partitionId: Int) {
50-
this(0, null, new Partition { override def index: Int = 0 }, null)
51+
this(0, 0, null, new Partition { override def index: Int = 0 }, null)
5152
}
5253

5354
@transient private val preferredLocs: Seq[TaskLocation] = {

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
4343
* @param stageId id of the stage this task belongs to
4444
* @param partitionId index of the number in the RDD
4545
*/
46-
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
46+
private[spark] abstract class Task[T](
47+
val stageId: Int,
48+
val stageAttemptId: Int,
49+
var partitionId: Int) extends Serializable {
4750

4851
/**
4952
* Called by [[Executor]] to run this task.
@@ -55,6 +58,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
5558
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
5659
context = new TaskContextImpl(
5760
stageId = stageId,
61+
stageAttemptId = stageAttemptId,
5862
partitionId = partitionId,
5963
taskAttemptId = taskAttemptId,
6064
attemptNumber = attemptNumber,

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ private[spark] class TaskSchedulerImpl(
163163
this.synchronized {
164164
val manager = createTaskSetManager(taskSet, maxTaskFailures)
165165
activeTaskSets(taskSet.id) = manager
166+
val taskSetsPerStage = activeTaskSets.values.filterNot(_.isZombie).groupBy(_.stageId)
167+
taskSetsPerStage.foreach { case (stage, taskSets) =>
168+
if (taskSets.size > 1) {
169+
throw new SparkIllegalStateException("more than one active taskSet for stage " + stage)
170+
}
171+
}
166172
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
167173

168174
if (!isLocal && !hasReceivedTask) {

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ public void persist() {
10111011
@Test
10121012
public void iterator() {
10131013
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
1014-
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
1014+
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, 0, new TaskMetrics());
10151015
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
10161016
}
10171017

core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
1919

2020
import org.apache.spark.TaskContext
2121

22-
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
22+
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
2323
override def runTask(context: TaskContext): Int = 0
2424

2525
override def preferredLocations: Seq[TaskLocation] = prefLocs

core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
2525
* A Task implementation that fails to serialize.
2626
*/
2727
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
28-
extends Task[Array[Byte]](stageId, 0) {
28+
extends Task[Array[Byte]](stageId, 0, 0) {
2929

3030
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
3131
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
4141
}
4242
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
4343
val func = (c: TaskContext, i: Iterator[String]) => i.next()
44-
val task = new ResultTask[String, String](
45-
0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
44+
val task = new ResultTask[String, String](0, 0,
45+
sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
4646
intercept[RuntimeException] {
4747
task.run(0, 0)
4848
}

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
135135
/**
136136
* A Task implementation that results in a large serialized task.
137137
*/
138-
class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
138+
class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
139139
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
140140
val random = new Random(0)
141141
random.nextBytes(randomBuffer)

0 commit comments

Comments
 (0)