[SPARK-23816][CORE] Killed tasks should ignore FetchFailures. #20987

Closed
wants to merge 4 commits into from
26 changes: 13 additions & 13 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -480,6 +480,19 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))

case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
@@ -494,19 +507,6 @@ private[spark] class Executor(
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))

case CausedBy(cDE: CommitDeniedException) =>
Contributor


I should have seen this when reviewing the original change, my bad; thanks for fixing this @squito !

Contributor Author


I should have caught it too :) Kay mentioned OOM handling on the original PR, but we didn't think about interrupts.

val reason = cDE.toTaskCommitDeniedReason
setTaskFinishedAndClearInterruptStatus()
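The substance of the Executor.scala change is ordering: the `TaskKilledException` and interrupted-task cases are moved above the `hasFetchFailure` guard, so a task that is killed while a fetch failure is pending is reported as KILLED rather than FAILED with a fetch-failure reason. The standalone sketch below (hypothetical names, not Spark code) illustrates why the order of guarded cases in a `catch` block decides which reason wins:

```scala
// Minimal sketch, not Spark code: a pending "fetch failure" and a kill can both
// be true at once, and whichever catch case is listed first decides the report.
object CatchOrderSketch {
  final case class TaskKilledException(reason: String) extends RuntimeException(reason)

  sealed trait Report
  case object Killed extends Report
  case object FailedWithFetchFailure extends Report

  // Before the fix: the fetch-failure guard comes first and swallows the kill.
  def reportBefore(hasFetchFailure: Boolean, failure: Throwable): Report =
    try throw failure catch {
      case _: Throwable if hasFetchFailure => FailedWithFetchFailure
      case _: TaskKilledException          => Killed
    }

  // After the fix: the kill case is checked first, so the fetch failure is ignored.
  def reportAfter(hasFetchFailure: Boolean, failure: Throwable): Report =
    try throw failure catch {
      case _: TaskKilledException          => Killed
      case _: Throwable if hasFetchFailure => FailedWithFetchFailure
    }

  def main(args: Array[String]): Unit = {
    val kill = TaskKilledException("killed: another attempt already finished")
    println(reportBefore(hasFetchFailure = true, kill)) // FailedWithFetchFailure (the bug)
    println(reportAfter(hasFetchFailure = true, kill))  // Killed (the behavior this PR wants)
  }
}
```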
92 changes: 75 additions & 17 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable.Map
import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
// the fetch failure. The executor should still tell the driver that the task failed due to a
// fetch failure, not a generic exception from user code.
val inputRDD = new FetchFailureThrowingRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
@@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}

test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
assert(failReason.isInstanceOf[ExceptionFailure])
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
assert(exceptionCaptor.getAllValues.size === 1)
assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
}

test("SPARK-23816: interrupts are not masked by a FetchFailure") {
// If killing the task causes a fetch failure, we still treat it as a task that was killed,
// as the fetch failure could easily be caused by interrupting the thread.
val (failReason, _) = testFetchFailureHandling(false)
assert(failReason.isInstanceOf[TaskKilled])
}

/**
* Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
* superseded by another error, either an OOM or intentionally killing a task.
* @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
* FetchFailure
*/
private def testFetchFailureHandling(
oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
// SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
// does not represent a real fetch failure.
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size

// Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
// the fetch failure as a false positive, and just do normal OOM handling.
// Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
// should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
val inputRDD = new FetchFailureThrowingRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
if (!oom) {
// we are trying to set up a case where a task is killed after a fetch failure -- this
// is just a helper to coordinate between the task thread and this thread that will
// kill the task
ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
}
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)

val (failReason, uncaughtExceptionHandler) =
runTaskGetFailReasonAndExceptionHandler(taskDescription)
// make sure the task failure just looks like a OOM, not a fetch failure
assert(failReason.isInstanceOf[ExceptionFailure])
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
assert(exceptionCaptor.getAllValues.size === 1)
assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
}
runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
}

test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}

private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
}

private def runTaskGetFailReasonAndExceptionHandler(
taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
taskDescription: TaskDescription,
killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
val timedOut = new AtomicBoolean(false)
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
if (killTask) {
val killingThread = new Thread("kill-task") {
override def run(): Unit = {
// wait to kill the task until it has thrown a fetch failure
if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
// now we can kill the task
executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
} else {
timedOut.set(true)
}
}
}
killingThread.start()
}
eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
} finally {
if (executor != null) {
executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
.statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
// first statusUpdate for RUNNING has empty data
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
class FetchFailureHidingRDD(
sc: SparkContext,
val input: FetchFailureThrowingRDD,
throwOOM: Boolean) extends RDD[Int](input) {
throwOOM: Boolean,
interrupt: Boolean) extends RDD[Int](input) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
val inItr = input.compute(split, context)
try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
case t: Throwable =>
if (throwOOM) {
throw new OutOfMemoryError("OOM while handling another exception")
} else if (interrupt) {
// make sure our test is set up correctly
assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
// signal our test is ready for the task to get killed
Member


Maybe say "signal killingThread in our test", since killingThread is what is actually waiting on latch1.

Contributor Author


I prefer the original comment -- the mechanics of what is waiting on the latch are easy enough to follow; it's more important to explain why.

ExecutorSuiteHelper.latches.latch1.countDown()
// then wait for another thread in the test to kill the task -- this latch
// is never actually decremented, we just wait to get killed.
ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
throw new IllegalStateException("timed out waiting to be interrupted")
} else {
throw new RuntimeException("User Exception that hides the original exception", t)
}
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
@volatile var testFailedReason: TaskFailedReason = _
}

// helper for coordinating killing tasks
private object ExecutorSuiteHelper {
var latches: ExecutorSuiteHelper = null
}

private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
def writeExternal(out: ObjectOutput): Unit = {}
def readExternal(in: ObjectInput): Unit = {
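The test coordination added here is a two-latch handshake: the task thread counts down `latch1` once the fetch failure has been recorded, then blocks on `latch2` (which is never counted down) until the kill interrupts it, while `killingThread` awaits `latch1` before calling `killAllTasks`. Below is a self-contained sketch of that pattern, with hypothetical names and plain threads standing in for the executor machinery:

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}

// Standalone sketch of the two-latch handshake (hypothetical names, not the suite itself).
object LatchHandshakeSketch {
  def main(args: Array[String]): Unit = {
    val ready   = new CountDownLatch(1) // plays the role of latch1
    val blocker = new CountDownLatch(1) // plays the role of latch2; never counted down

    val worker = new Thread("task-thread") {
      override def run(): Unit = {
        // ... reach the interesting state first (in the suite: a recorded fetch failure) ...
        ready.countDown() // tell the killing thread we are ready to be killed
        try {
          // block until interrupted; a timeout here means the harness is broken
          if (!blocker.await(10, TimeUnit.SECONDS)) {
            throw new IllegalStateException("timed out waiting to be interrupted")
          }
        } catch {
          case _: InterruptedException => println("interrupted, as expected")
        }
      }
    }
    worker.start()

    // the "killing thread": wait for the signal, then interrupt (kill) the worker
    if (ready.await(10, TimeUnit.SECONDS)) worker.interrupt()
    worker.join()
  }
}
```

Never counting down `latch2` is what guarantees the task is still blocked when the kill arrives, so the interrupt, rather than a normal completion, is what ends it.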