@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.Map
 import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // the fetch failure. The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  /**
+   * Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
+   * superseded by another error, either an OOM or intentionally killing a task.
+   * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+   *            FetchFailure
+   */
+  private def testFetchFailureHandling(
+      oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
+    // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+    // does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
     val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
 
-    // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
-    // the fetch failure as a false positive, and just do normal OOM handling.
+    // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
+    // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)
 
-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }
 
   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }
 
   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    val timedOut = new AtomicBoolean(false)
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        val killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+              // now we can kill the task
+              executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
+            } else {
+              timedOut.set(true)
+            }
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
+      assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
     } finally {
       if (executor != null) {
         executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
     orderedMock.verify(mockBackend)
       .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
     orderedMock.verify(mockBackend)
-      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
     // first statusUpdate for RUNNING has empty data
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is set up correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented, we just wait to get killed.
+          ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+          throw new IllegalStateException("timed out waiting to be interrupted")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }
 
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {