@@ -139,7 +139,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // the fetch failure. The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,8 +173,26 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }

   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  def testFetchFailureHandling(oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
+    // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+    // does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
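The SPARK-19276 test above checks the fatal-error path with Mockito's ArgumentCaptor: it captures the Throwable handed to the mocked UncaughtExceptionHandler and asserts that exactly one OutOfMemoryError was reported. For readers unfamiliar with that idiom, here is a minimal, self-contained sketch of the same pattern; it assumes Mockito 2.x (`ArgumentMatchers`), and `CaptorSketch` and the "boom" error are illustrative names, not part of this patch:

import java.lang.Thread.UncaughtExceptionHandler

import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{mock, verify}

object CaptorSketch {
  def main(args: Array[String]): Unit = {
    val handler = mock(classOf[UncaughtExceptionHandler])
    // simulate what the executor does on a fatal error: report it to the handler
    handler.uncaughtException(Thread.currentThread(), new OutOfMemoryError("boom"))
    // capture whatever Throwable the handler received, then assert on it
    val captor = ArgumentCaptor.forClass(classOf[Throwable])
    verify(handler).uncaughtException(any(), captor.capture())
    assert(captor.getAllValues.size == 1)
    assert(captor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
  }
}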
@@ -183,7 +201,13 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
     // the fetch failure as a false positive, and just do normal OOM handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +224,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)

-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }

   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,19 +274,32 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }

   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }

   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    var killingThread: Thread = null
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            ExecutorSuiteHelper.latches.latch1.await()
+            // now we can kill the task
+            executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
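The kill-task thread above pairs with the latches used in FetchFailureHidingRDD (see the later hunks): latch1 lets the task signal that it has recorded a fetch failure and is ready to be killed, while latch2 blocks it until the kill actually lands. A minimal sketch of that two-latch handshake in isolation; `HandshakeSketch`, `ready`, and `never` are illustrative names, and `interrupt()` stands in for `executor.killAllTasks(...)`:

import java.util.concurrent.CountDownLatch

object HandshakeSketch {
  def main(args: Array[String]): Unit = {
    val ready = new CountDownLatch(1) // plays the role of latch1
    val never = new CountDownLatch(1) // plays the role of latch2: never counted down
    val worker = new Thread("worker") {
      override def run(): Unit = {
        ready.countDown() // signal: safe to kill me now
        try {
          never.await() // block until the coordinator interrupts us
        } catch {
          case _: InterruptedException => // the expected "kill"
        }
      }
    }
    worker.start()
    ready.await()      // wait until the worker is in position
    worker.interrupt() // deterministic stand-in for killing the task
    worker.join()
  }
}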
@@ -282,8 +312,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
     orderedMock.verify(mockBackend)
       .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
     orderedMock.verify(mockBackend)
-      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
     // first statusUpdate for RUNNING has empty data
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
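The orderedMock verification relies on Mockito's InOrder API to assert that the RUNNING update is reported before the terminal KILLED/FAILED update. A stripped-down sketch of that idiom, assuming Mockito 2.x; `StatusSink` and `InOrderSketch` are invented stand-ins for the ExecutorBackend callback, not part of this patch:

import org.mockito.Mockito.{inOrder, mock}

// a tiny stand-in for the ExecutorBackend status-update callback
trait StatusSink {
  def statusUpdate(taskId: Long, state: String): Unit
}

object InOrderSketch {
  def main(args: Array[String]): Unit = {
    val sink = mock(classOf[StatusSink])
    sink.statusUpdate(0L, "RUNNING")
    sink.statusUpdate(0L, "KILLED")
    // verification fails if the two calls happened in the opposite order
    val ordered = inOrder(sink)
    ordered.verify(sink).statusUpdate(0L, "RUNNING")
    ordered.verify(sink).statusUpdate(0L, "KILLED")
  }
}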
@@ -321,7 +352,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +362,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is set up correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal that our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented, we just wait to get killed.
+          ExecutorSuiteHelper.latches.latch2.await()
+          throw new IllegalStateException("impossible")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
@@ -352,6 +393,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }

+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {