@@ -22,19 +22,21 @@ import java.net.MalformedURLException
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
+import scala.concurrent.duration._
 import scala.concurrent.Await
-import scala.concurrent.duration.Duration
 
 import com.google.common.io.Files
 import org.apache.hadoop.io.{BytesWritable, LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
+import org.scalatest.concurrent.Eventually
 import org.scalatest.Matchers._
 
-import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
 import org.apache.spark.util.Utils
 
-class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
+
+class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 
   test("Only one SparkContext may be active at a time") {
     // Regression test for SPARK-4180
@@ -465,4 +467,77 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
     assert(!sc.listenerBus.listeners.contains(sparkListener1))
     assert(sc.listenerBus.listeners.contains(sparkListener2))
   }
+
+  test("Cancelling stages/jobs with custom reasons.") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    val REASON = "You shall not pass"
+
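+    // Listener that performs the cancellation with the custom reason once a
+    // task is known to have started running.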
+    val listener = new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        if (SparkContextSuite.cancelStage) {
+          eventually(timeout(10.seconds)) {
+            assert(SparkContextSuite.isTaskStarted)
+          }
+          sc.cancelStage(taskStart.stageId, REASON)
+          SparkContextSuite.cancelStage = false
+        }
+      }
+
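+      // Same as above, but cancels at the job level rather than the stage level.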
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        if (SparkContextSuite.cancelJob) {
+          eventually(timeout(10.seconds)) {
+            assert(SparkContextSuite.isTaskStarted)
+          }
+          sc.cancelJob(jobStart.jobId, REASON)
+          SparkContextSuite.cancelJob = false
+        }
+      }
+    }
+    sc.addSparkListener(listener)
+
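+    // Run the same action twice: once cancelled by stage, once by job.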
+    for (cancelWhat <- Seq("stage", "job")) {
+      SparkContextSuite.isTaskStarted = false
+      SparkContextSuite.cancelStage = (cancelWhat == "stage")
+      SparkContextSuite.cancelJob = (cancelWhat == "job")
+
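+      // The cancellation aborts the action, which surfaces as a SparkException.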
+      val ex = intercept[SparkException] {
+        sc.range(0, 10000L).map { x =>
+          if (x == 10L) {
+            org.apache.spark.SparkContextSuite.isTaskStarted = true
+          }
+          x
+        }.cartesian(sc.range(0, 10L)).count()
+      }
+
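+      // Depending on which path reported the failure, the custom reason is on
+      // the exception itself or on its SparkException cause.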
+      ex.getCause() match {
+        case null =>
+          assert(ex.getMessage().contains(REASON))
+        case cause: SparkException =>
+          assert(cause.getMessage().contains(REASON))
+        case cause: Throwable =>
+          fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
+      }
+
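+      // Wait until no tasks are reported as running before the next round.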
+      eventually(timeout(20.seconds)) {
+        assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+      }
+    }
+  }
+
+}
+
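+// The flags live in a companion object so that the listener (driver side) and
+// the task closures observe the same @volatile fields when running in local mode.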
+object SparkContextSuite {
+  @volatile var cancelJob = false
+  @volatile var cancelStage = false
+  @volatile var isTaskStarted = false
 }