Skip to content

Commit d886ba3

Browse files
committed
[SPARK-23816][CORE] Killed tasks should ignore FetchFailures.
SPARK-19276 ensured that FetchFailures do not get swallowed by other layers of exception handling, but it also meant that a killed task could look like a fetch failure. This is particularly a problem with speculative execution, where we expect to kill tasks as they are reading shuffle data. The fix is to ensure that we always check for killed tasks first. Added a new unit test that fails before the fix, and ran it 1k times to check for flakiness. Also ran the full suite of tests on Jenkins.
1 parent cccaaa1 commit d886ba3

File tree

2 files changed

+74
-28
lines changed

2 files changed

+74
-28
lines changed

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,19 @@ private[spark] class Executor(
480480
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
481481

482482
} catch {
483+
case t: TaskKilledException =>
484+
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
485+
setTaskFinishedAndClearInterruptStatus()
486+
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
487+
488+
case _: InterruptedException | NonFatal(_) if
489+
task != null && task.reasonIfKilled.isDefined =>
490+
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
491+
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
492+
setTaskFinishedAndClearInterruptStatus()
493+
execBackend.statusUpdate(
494+
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
495+
483496
case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
484497
val reason = task.context.fetchFailed.get.toTaskFailedReason
485498
if (!t.isInstanceOf[FetchFailedException]) {
@@ -494,19 +507,6 @@ private[spark] class Executor(
494507
setTaskFinishedAndClearInterruptStatus()
495508
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
496509

497-
case t: TaskKilledException =>
498-
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
499-
setTaskFinishedAndClearInterruptStatus()
500-
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
501-
502-
case _: InterruptedException | NonFatal(_) if
503-
task != null && task.reasonIfKilled.isDefined =>
504-
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
505-
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
506-
setTaskFinishedAndClearInterruptStatus()
507-
execBackend.statusUpdate(
508-
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
509-
510510
case CausedBy(cDE: CommitDeniedException) =>
511511
val reason = cDE.toTaskCommitDeniedReason
512512
setTaskFinishedAndClearInterruptStatus()

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
139139
// the fetch failure. The executor should still tell the driver that the task failed due to a
140140
// fetch failure, not a generic exception from user code.
141141
val inputRDD = new FetchFailureThrowingRDD(sc)
142-
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
142+
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
143143
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
144144
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
145145
val task = new ResultTask(
@@ -173,8 +173,26 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
173173
}
174174

175175
test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
176+
val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
177+
assert(failReason.isInstanceOf[ExceptionFailure])
178+
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
179+
verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
180+
assert(exceptionCaptor.getAllValues.size === 1)
181+
assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
182+
}
183+
184+
test(s"SPARK-23816: interrupts are not masked by a FetchFailure") {
185+
// If killing the task causes a fetch failure, we still treat it as a task that was killed,
186+
// as the fetch failure could easily be caused by interrupting the thread.
187+
val (failReason, _) = testFetchFailureHandling(false)
188+
assert(failReason.isInstanceOf[TaskKilled])
189+
}
190+
191+
def testFetchFailureHandling(oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
176192
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
177193
// may be a false positive. And we should call the uncaught exception handler.
194+
// SPARK-23816: interrupts are also handled the same way, as killing an obsolete speculative task
195+
// does not represent a real fetch failure.
178196
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
179197
sc = new SparkContext(conf)
180198
val serializer = SparkEnv.get.closureSerializer.newInstance()
@@ -183,7 +201,13 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
183201
// Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
184202
// the fetch failure as a false positive, and just do normal OOM handling.
185203
val inputRDD = new FetchFailureThrowingRDD(sc)
186-
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
204+
if (!oom) {
205+
// we are trying to set up a case where a task is killed after a fetch failure -- this
206+
// is just a helper to coordinate between the task thread and this thread that will
207+
// kill the task
208+
ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
209+
}
210+
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
187211
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
188212
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
189213
val task = new ResultTask(
@@ -200,15 +224,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
200224
val serTask = serializer.serialize(task)
201225
val taskDescription = createFakeTaskDescription(serTask)
202226

203-
val (failReason, uncaughtExceptionHandler) =
204-
runTaskGetFailReasonAndExceptionHandler(taskDescription)
205-
// make sure the task failure just looks like a OOM, not a fetch failure
206-
assert(failReason.isInstanceOf[ExceptionFailure])
207-
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
208-
verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
209-
assert(exceptionCaptor.getAllValues.size === 1)
210-
assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
211-
}
227+
runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
228+
}
212229

213230
test("Gracefully handle error in task deserialization") {
214231
val conf = new SparkConf
@@ -257,19 +274,32 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
257274
}
258275

259276
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
260-
runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
277+
runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
261278
}
262279

263280
private def runTaskGetFailReasonAndExceptionHandler(
264-
taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
281+
taskDescription: TaskDescription,
282+
killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
265283
val mockBackend = mock[ExecutorBackend]
266284
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
267285
var executor: Executor = null
286+
var killingThread: Thread = null
268287
try {
269288
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
270289
uncaughtExceptionHandler = mockUncaughtExceptionHandler)
271290
// the task will be launched in a dedicated worker thread
272291
executor.launchTask(mockBackend, taskDescription)
292+
if (killTask) {
293+
killingThread = new Thread("kill-task") {
294+
override def run(): Unit = {
295+
// wait to kill the task until it has thrown a fetch failure
296+
ExecutorSuiteHelper.latches.latch1.await()
297+
// now we can kill the task
298+
executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
299+
}
300+
}
301+
killingThread.start()
302+
}
273303
eventually(timeout(5.seconds), interval(10.milliseconds)) {
274304
assert(executor.numRunningTasks === 0)
275305
}
@@ -282,8 +312,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
282312
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
283313
orderedMock.verify(mockBackend)
284314
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
315+
val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
285316
orderedMock.verify(mockBackend)
286-
.statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
317+
.statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
287318
// first statusUpdate for RUNNING has empty data
288319
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
289320
// second update is more interesting
@@ -321,7 +352,8 @@ class SimplePartition extends Partition {
321352
class FetchFailureHidingRDD(
322353
sc: SparkContext,
323354
val input: FetchFailureThrowingRDD,
324-
throwOOM: Boolean) extends RDD[Int](input) {
355+
throwOOM: Boolean,
356+
interrupt: Boolean) extends RDD[Int](input) {
325357
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
326358
val inItr = input.compute(split, context)
327359
try {
@@ -330,6 +362,15 @@ class FetchFailureHidingRDD(
330362
case t: Throwable =>
331363
if (throwOOM) {
332364
throw new OutOfMemoryError("OOM while handling another exception")
365+
} else if (interrupt) {
366+
// make sure our test is set up correctly
367+
assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
368+
// signal our test is ready for the task to get killed
369+
ExecutorSuiteHelper.latches.latch1.countDown()
370+
// then wait for another thread in the test to kill the task -- this latch
371+
// is never actually decremented, we just wait to get killed.
372+
ExecutorSuiteHelper.latches.latch2.await()
373+
throw new IllegalStateException("impossible")
333374
} else {
334375
throw new RuntimeException("User Exception that hides the original exception", t)
335376
}
@@ -352,6 +393,11 @@ private class ExecutorSuiteHelper {
352393
@volatile var testFailedReason: TaskFailedReason = _
353394
}
354395

396+
// helper for coordinating killing tasks
397+
private object ExecutorSuiteHelper {
398+
var latches: ExecutorSuiteHelper = null
399+
}
400+
355401
private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
356402
def writeExternal(out: ObjectOutput): Unit = {}
357403
def readExternal(in: ObjectInput): Unit = {

0 commit comments

Comments
 (0)