Commit e36c053

alarxin authored and committed
[SPARK-19549] Allow providing reason for stage/job cancelling
This change adds an optional argument to `SparkContext.cancelStage()` and `SparkContext.cancelJob()`, which allows the caller to provide the exact reason for the cancellation. Adds unit tests.

Author: Ala Luszczak <ala@databricks.com>

Closes apache#16887 from ala/cancel.
1 parent 285915d commit e36c053

7 files changed: +138 -28 lines changed
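To see what the change looks like from the caller's side, here is a minimal usage sketch based on the signatures added in this commit. The IDs and reason strings are placeholders; in practice the IDs would come from a `SparkListener` or `sc.statusTracker`.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical usage sketch; the IDs and reason strings below are placeholders.
val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))

// Cancel a job and say why; the reason surfaces in the resulting
// SparkException message ("Job 42 cancelled <reason>").
sc.cancelJob(42, "input no longer needed")

// Cancel a stage (and every job that depends on it) with a reason.
sc.cancelStage(7, "stage exceeded its latency budget")

// The pre-existing no-reason overloads keep working unchanged.
sc.cancelJob(42)
sc.cancelStage(7)
```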

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 26 additions & 4 deletions
```diff
@@ -2062,10 +2062,32 @@ class SparkContext(config: SparkConf) extends Logging {
    * Cancel a given job if it's scheduled or running.
    *
    * @param jobId the job ID to cancel
+   * @param reason optional reason for cancellation
    * @note Throws `InterruptedException` if the cancel message cannot be sent
    */
-  def cancelJob(jobId: Int) {
-    dagScheduler.cancelJob(jobId)
+  def cancelJob(jobId: Int, reason: String): Unit = {
+    dagScheduler.cancelJob(jobId, Option(reason))
+  }
+
+  /**
+   * Cancel a given job if it's scheduled or running.
+   *
+   * @param jobId the job ID to cancel
+   * @note Throws `InterruptedException` if the cancel message cannot be sent
+   */
+  def cancelJob(jobId: Int): Unit = {
+    dagScheduler.cancelJob(jobId, None)
+  }
+
+  /**
+   * Cancel a given stage and all jobs associated with it.
+   *
+   * @param stageId the stage ID to cancel
+   * @param reason reason for cancellation
+   * @note Throws `InterruptedException` if the cancel message cannot be sent
+   */
+  def cancelStage(stageId: Int, reason: String): Unit = {
+    dagScheduler.cancelStage(stageId, Option(reason))
   }
 
   /**
@@ -2074,8 +2096,8 @@ class SparkContext(config: SparkConf) extends Logging {
    * @param stageId the stage ID to cancel
    * @note Throws `InterruptedException` if the cancel message cannot be sent
    */
-  def cancelStage(stageId: Int) {
-    dagScheduler.cancelStage(stageId)
+  def cancelStage(stageId: Int): Unit = {
+    dagScheduler.cancelStage(stageId, None)
   }
 
   /**
```

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 21 additions & 14 deletions
```diff
@@ -696,9 +696,9 @@ class DAGScheduler(
   /**
    * Cancel a job that is running or waiting in the queue.
    */
-  def cancelJob(jobId: Int): Unit = {
+  def cancelJob(jobId: Int, reason: Option[String]): Unit = {
     logInfo("Asked to cancel job " + jobId)
-    eventProcessLoop.post(JobCancelled(jobId))
+    eventProcessLoop.post(JobCancelled(jobId, reason))
   }
 
   /**
@@ -719,16 +719,16 @@
   private[scheduler] def doCancelAllJobs() {
     // Cancel all running jobs.
     runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
-      reason = "as part of cancellation of all jobs"))
+      Option("as part of cancellation of all jobs")))
     activeJobs.clear() // These should already be empty by this point,
     jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
   }
 
   /**
    * Cancel all jobs associated with a running or scheduled stage.
    */
-  def cancelStage(stageId: Int) {
-    eventProcessLoop.post(StageCancelled(stageId))
+  def cancelStage(stageId: Int, reason: Option[String]) {
+    eventProcessLoop.post(StageCancelled(stageId, reason))
   }
 
   /**
@@ -785,7 +785,8 @@
       }
     }
     val jobIds = activeInGroup.map(_.jobId)
-    jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId)))
+    jobIds.foreach(handleJobCancellation(_,
+      Option("part of cancelled job group %s".format(groupId))))
   }
 
   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
@@ -1356,24 +1357,30 @@
     }
   }
 
-  private[scheduler] def handleStageCancellation(stageId: Int) {
+  private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]) {
     stageIdToStage.get(stageId) match {
       case Some(stage) =>
         val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
         jobsThatUseStage.foreach { jobId =>
-          handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
+          val reasonStr = reason match {
+            case Some(originalReason) =>
+              s"because $originalReason"
+            case None =>
+              s"because Stage $stageId was cancelled"
+          }
+          handleJobCancellation(jobId, Option(reasonStr))
         }
       case None =>
         logInfo("No active jobs to kill for Stage " + stageId)
     }
   }
 
-  private[scheduler] def handleJobCancellation(jobId: Int, reason: String = "") {
+  private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]) {
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to cancel unregistered job " + jobId)
     } else {
       failJobAndIndependentStages(
-        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
+        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason.getOrElse("")))
     }
   }
 
@@ -1615,11 +1622,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
       dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)
 
-    case StageCancelled(stageId) =>
-      dagScheduler.handleStageCancellation(stageId)
+    case StageCancelled(stageId, reason) =>
+      dagScheduler.handleStageCancellation(stageId, reason)
 
-    case JobCancelled(jobId) =>
-      dagScheduler.handleJobCancellation(jobId)
+    case JobCancelled(jobId, reason) =>
+      dagScheduler.handleJobCancellation(jobId, reason)
 
     case JobGroupCancelled(groupId) =>
       dagScheduler.handleJobGroupCancelled(groupId)
```
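To make the message flow above concrete, here is a small standalone sketch. It is not Spark code, just the same string logic as `handleStageCancellation` and `handleJobCancellation` in this diff, extracted into two plain functions:

```scala
// Standalone sketch mirroring the message construction in the diff above;
// not part of Spark itself.
def stageCancelReason(stageId: Int, reason: Option[String]): String =
  reason match {
    case Some(originalReason) => s"because $originalReason"
    case None => s"because Stage $stageId was cancelled"
  }

def jobCancelMessage(jobId: Int, reason: Option[String]): String =
  "Job %d cancelled %s".format(jobId, reason.getOrElse(""))

// SparkContext.cancelStage(7, "You shall not pass") ultimately yields:
jobCancelMessage(4, Some(stageCancelReason(7, Some("You shall not pass"))))
// => "Job 4 cancelled because You shall not pass"

// With no custom reason the old message is preserved:
jobCancelMessage(4, Some(stageCancelReason(7, None)))
// => "Job 4 cancelled because Stage 7 was cancelled"

// A direct job cancellation with no reason keeps the old behavior:
jobCancelMessage(4, None)
// => "Job 4 cancelled " (nothing is appended)
```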

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 8 additions & 2 deletions
```diff
@@ -53,9 +53,15 @@ private[scheduler] case class MapStageSubmitted(
     properties: Properties = null)
   extends DAGSchedulerEvent
 
-private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent
+private[scheduler] case class StageCancelled(
+    stageId: Int,
+    reason: Option[String])
+  extends DAGSchedulerEvent
 
-private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(
+    jobId: Int,
+    reason: Option[String])
+  extends DAGSchedulerEvent
 
 private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
```
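Note that `reason` is added to these events as a required `Option[String]` rather than a parameter with a default value, so every producer of a cancellation event must state explicitly whether a reason exists. This appears to be why the one-line call sites in `JobWaiter` and `DAGSchedulerSuite` below now pass `None`.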

core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -50,7 +50,7 @@ private[spark] class JobWaiter[T](
    * will fail this job with a SparkException.
    */
   def cancel() {
-    dagScheduler.cancelJob(jobId)
+    dagScheduler.cancelJob(jobId, None)
   }
 
   override def taskSucceeded(index: Int, result: Any): Unit = {
```

core/src/test/scala/org/apache/spark/SparkContextSuite.scala

Lines changed: 66 additions & 2 deletions
```diff
@@ -22,18 +22,21 @@ import java.net.MalformedURLException
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
+import scala.concurrent.duration._
 import scala.concurrent.Await
-import scala.concurrent.duration.Duration
 
 import com.google.common.io.Files
 import org.apache.hadoop.io.{BytesWritable, LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
+import org.scalatest.concurrent.Eventually
 import org.scalatest.Matchers._
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
 import org.apache.spark.util.Utils
 
-class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
+
+class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 
   test("Only one SparkContext may be active at a time") {
     // Regression test for SPARK-4180
@@ -451,4 +454,65 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
       sc.stop()
     }
   }
+
+  test("Cancelling stages/jobs with custom reasons.") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    val REASON = "You shall not pass"
+
+    val listener = new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        if (SparkContextSuite.cancelStage) {
+          eventually(timeout(10.seconds)) {
+            assert(SparkContextSuite.isTaskStarted)
+          }
+          sc.cancelStage(taskStart.stageId, REASON)
+          SparkContextSuite.cancelStage = false
+        }
+      }
+
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        if (SparkContextSuite.cancelJob) {
+          eventually(timeout(10.seconds)) {
+            assert(SparkContextSuite.isTaskStarted)
+          }
+          sc.cancelJob(jobStart.jobId, REASON)
+          SparkContextSuite.cancelJob = false
+        }
+      }
+    }
+    sc.addSparkListener(listener)
+
+    for (cancelWhat <- Seq("stage", "job")) {
+      SparkContextSuite.isTaskStarted = false
+      SparkContextSuite.cancelStage = (cancelWhat == "stage")
+      SparkContextSuite.cancelJob = (cancelWhat == "job")
+
+      val ex = intercept[SparkException] {
+        sc.range(0, 10000L).mapPartitions { x =>
+          org.apache.spark.SparkContextSuite.isTaskStarted = true
+          x
+        }.cartesian(sc.range(0, 10L)).count()
+      }
+
+      ex.getCause() match {
+        case null =>
+          assert(ex.getMessage().contains(REASON))
+        case cause: SparkException =>
+          assert(cause.getMessage().contains(REASON))
+        case cause: Throwable =>
+          fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
+      }
+
+      eventually(timeout(20.seconds)) {
+        assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+      }
+    }
+  }
+
+}
+
+object SparkContextSuite {
+  @volatile var cancelJob = false
+  @volatile var cancelStage = false
+  @volatile var isTaskStarted = false
 }
```
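A note on the test's structure: the `cancelJob`/`cancelStage`/`isTaskStarted` flags live in a companion object and are marked `@volatile`. The `mapPartitions` closure is serialized and run on task threads, so it cannot usefully mutate fields of the test instance; a static, volatile flag works here because the suite runs with a `local` master, where the driver, the listener, and the tasks all share one JVM, and `@volatile` makes writes from the task thread visible to the listener thread.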

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -329,7 +329,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
   /** Sends JobCancelled to the DAG scheduler. */
   private def cancel(jobId: Int) {
-    runEvent(JobCancelled(jobId))
+    runEvent(JobCancelled(jobId, None))
   }
 
   test("[SPARK-3353] parent stage should have lower stage id") {
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala

Lines changed: 15 additions & 4 deletions
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 
 class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
+  import testImplicits._
 
   test("SPARK-7150 range api") {
     // numSlice is greater than length
@@ -137,25 +138,30 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
   test("Cancelling stage in a query with Range.") {
     val listener = new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        Thread.sleep(100)
+        eventually(timeout(10.seconds)) {
+          assert(DataFrameRangeSuite.isTaskStarted)
+        }
         sparkContext.cancelStage(taskStart.stageId)
       }
     }
 
     sparkContext.addSparkListener(listener)
     for (codegen <- Seq(true, false)) {
       withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+        DataFrameRangeSuite.isTaskStarted = false
         val ex = intercept[SparkException] {
-          spark.range(100000L).crossJoin(spark.range(100000L))
-            .toDF("a", "b").agg(sum("a"), sum("b")).collect()
+          spark.range(100000L).mapPartitions { x =>
+            DataFrameRangeSuite.isTaskStarted = true
+            x
+          }.crossJoin(spark.range(100L)).toDF("a", "b").agg(sum("a"), sum("b")).collect()
         }
         ex.getCause() match {
           case null =>
             assert(ex.getMessage().contains("cancelled"))
           case cause: SparkException =>
             assert(cause.getMessage().contains("cancelled"))
           case cause: Throwable =>
-            fail("Expected the casue to be SparkException, got " + cause.toString() + " instead.")
+            fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
         }
       }
       eventually(timeout(20.seconds)) {
@@ -164,3 +170,8 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     }
   }
 }
+
+object DataFrameRangeSuite {
+  @volatile var isTaskStarted = false
+}
+
```
