
Commit d785217

ala authored and rxin committed
[SPARK-19549] Allow providing reason for stage/job cancelling
## What changes were proposed in this pull request?

This change adds an optional argument to the `SparkContext.cancelStage()` and `SparkContext.cancelJob()` functions, allowing the caller to provide the exact reason for the cancellation.

## How was this patch tested?

Adds a unit test.

Author: Ala Luszczak <ala@databricks.com>

Closes #16887 from ala/cancel.
1 parent 3a43ae7 commit d785217

File tree

7 files changed: +138 -29 lines changed


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 26 additions & 4 deletions
@@ -2207,10 +2207,32 @@ class SparkContext(config: SparkConf) extends Logging {
    * Cancel a given job if it's scheduled or running.
    *
    * @param jobId the job ID to cancel
+   * @param reason optional reason for cancellation
    * @note Throws `InterruptedException` if the cancel message cannot be sent
    */
-  def cancelJob(jobId: Int) {
-    dagScheduler.cancelJob(jobId)
+  def cancelJob(jobId: Int, reason: String): Unit = {
+    dagScheduler.cancelJob(jobId, Option(reason))
+  }
+
+  /**
+   * Cancel a given job if it's scheduled or running.
+   *
+   * @param jobId the job ID to cancel
+   * @note Throws `InterruptedException` if the cancel message cannot be sent
+   */
+  def cancelJob(jobId: Int): Unit = {
+    dagScheduler.cancelJob(jobId, None)
+  }
+
+  /**
+   * Cancel a given stage and all jobs associated with it.
+   *
+   * @param stageId the stage ID to cancel
+   * @param reason reason for cancellation
+   * @note Throws `InterruptedException` if the cancel message cannot be sent
+   */
+  def cancelStage(stageId: Int, reason: String): Unit = {
+    dagScheduler.cancelStage(stageId, Option(reason))
   }
 
   /**
@@ -2219,8 +2241,8 @@ class SparkContext(config: SparkConf) extends Logging {
    * @param stageId the stage ID to cancel
    * @note Throws `InterruptedException` if the cancel message cannot be sent
    */
-  def cancelStage(stageId: Int) {
-    dagScheduler.cancelStage(stageId)
+  def cancelStage(stageId: Int): Unit = {
+    dagScheduler.cancelStage(stageId, None)
   }
 
   /**
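Taken together, the new overloads let a caller attach a human-readable reason while the no-argument variants keep the old behaviour. A minimal usage sketch follows (the job and stage IDs are illustrative placeholders, not from the patch; in practice they would come from a SparkListener or the SparkStatusTracker):

// Hedged usage sketch -- IDs below are placeholders.
val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))
sc.cancelJob(0, "user aborted the query")           // reason is wrapped in Option(...) and forwarded
sc.cancelStage(1, "stage exceeded its time budget") // same for stages
sc.cancelJob(0)                                     // old behaviour: reason is None
sc.cancelStage(1)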

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 21 additions & 14 deletions
@@ -696,9 +696,9 @@ class DAGScheduler(
   /**
    * Cancel a job that is running or waiting in the queue.
    */
-  def cancelJob(jobId: Int): Unit = {
+  def cancelJob(jobId: Int, reason: Option[String]): Unit = {
     logInfo("Asked to cancel job " + jobId)
-    eventProcessLoop.post(JobCancelled(jobId))
+    eventProcessLoop.post(JobCancelled(jobId, reason))
   }
 
   /**
@@ -719,16 +719,16 @@ class DAGScheduler(
   private[scheduler] def doCancelAllJobs() {
     // Cancel all running jobs.
     runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
-      reason = "as part of cancellation of all jobs"))
+      Option("as part of cancellation of all jobs")))
     activeJobs.clear() // These should already be empty by this point,
     jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
   }
 
   /**
    * Cancel all jobs associated with a running or scheduled stage.
    */
-  def cancelStage(stageId: Int) {
-    eventProcessLoop.post(StageCancelled(stageId))
+  def cancelStage(stageId: Int, reason: Option[String]) {
+    eventProcessLoop.post(StageCancelled(stageId, reason))
   }
 
   /**
@@ -785,7 +785,8 @@ class DAGScheduler(
       }
     }
     val jobIds = activeInGroup.map(_.jobId)
-    jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId)))
+    jobIds.foreach(handleJobCancellation(_,
+      Option("part of cancelled job group %s".format(groupId))))
   }
 
   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
@@ -1377,24 +1378,30 @@ class DAGScheduler(
     }
   }
 
-  private[scheduler] def handleStageCancellation(stageId: Int) {
+  private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]) {
     stageIdToStage.get(stageId) match {
       case Some(stage) =>
         val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
         jobsThatUseStage.foreach { jobId =>
-          handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
+          val reasonStr = reason match {
+            case Some(originalReason) =>
+              s"because $originalReason"
+            case None =>
+              s"because Stage $stageId was cancelled"
+          }
+          handleJobCancellation(jobId, Option(reasonStr))
         }
       case None =>
         logInfo("No active jobs to kill for Stage " + stageId)
     }
   }
 
-  private[scheduler] def handleJobCancellation(jobId: Int, reason: String = "") {
+  private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]) {
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to cancel unregistered job " + jobId)
     } else {
       failJobAndIndependentStages(
-        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
+        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason.getOrElse("")))
     }
   }
 
@@ -1636,11 +1643,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
       dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)
 
-    case StageCancelled(stageId) =>
-      dagScheduler.handleStageCancellation(stageId)
+    case StageCancelled(stageId, reason) =>
+      dagScheduler.handleStageCancellation(stageId, reason)
 
-    case JobCancelled(jobId) =>
-      dagScheduler.handleJobCancellation(jobId)
+    case JobCancelled(jobId, reason) =>
+      dagScheduler.handleJobCancellation(jobId, reason)
 
     case JobGroupCancelled(groupId) =>
       dagScheduler.handleJobGroupCancelled(groupId)
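Note how the reason threads through the scheduler: SparkContext wraps the string in Option(...), the event carries an Option[String], and handleJobCancellation folds it into the failure message. A distilled sketch of the resulting format (formatCancellationMessage is a hypothetical helper mirroring the "Job %d cancelled %s" logic above, shown for illustration only):

// Sketch only: mirrors the formatting inside handleJobCancellation.
def formatCancellationMessage(jobId: Int, reason: Option[String]): String =
  "Job %d cancelled %s".format(jobId, reason.getOrElse(""))

formatCancellationMessage(3, Some("because Stage 7 was cancelled"))
// => "Job 3 cancelled because Stage 7 was cancelled"
formatCancellationMessage(3, None)
// => "Job 3 cancelled " (note the trailing space when no reason is supplied)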

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 8 additions & 2 deletions
@@ -53,9 +53,15 @@ private[scheduler] case class MapStageSubmitted(
     properties: Properties = null)
   extends DAGSchedulerEvent
 
-private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent
+private[scheduler] case class StageCancelled(
+    stageId: Int,
+    reason: Option[String])
+  extends DAGSchedulerEvent
 
-private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(
+    jobId: Int,
+    reason: Option[String])
+  extends DAGSchedulerEvent
 
 private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent

core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ private[spark] class JobWaiter[T](
    * will fail this job with a SparkException.
    */
   def cancel() {
-    dagScheduler.cancelJob(jobId)
+    dagScheduler.cancelJob(jobId, None)
   }
 
   override def taskSucceeded(index: Int, result: Any): Unit = {

core/src/test/scala/org/apache/spark/SparkContextSuite.scala

Lines changed: 66 additions & 3 deletions
@@ -22,19 +22,21 @@ import java.net.MalformedURLException
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
+import scala.concurrent.duration._
 import scala.concurrent.Await
-import scala.concurrent.duration.Duration
 
 import com.google.common.io.Files
 import org.apache.hadoop.io.{BytesWritable, LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
+import org.scalatest.concurrent.Eventually
 import org.scalatest.Matchers._
 
-import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
 import org.apache.spark.util.Utils
 
-class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
+
+class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 
   test("Only one SparkContext may be active at a time") {
     // Regression test for SPARK-4180
@@ -465,4 +467,65 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
     assert(!sc.listenerBus.listeners.contains(sparkListener1))
     assert(sc.listenerBus.listeners.contains(sparkListener2))
   }
+
+  test("Cancelling stages/jobs with custom reasons.") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    val REASON = "You shall not pass"
+
+    val listener = new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+        if (SparkContextSuite.cancelStage) {
+          eventually(timeout(10.seconds)) {
+            assert(SparkContextSuite.isTaskStarted)
+          }
+          sc.cancelStage(taskStart.stageId, REASON)
+          SparkContextSuite.cancelStage = false
+        }
+      }
+
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        if (SparkContextSuite.cancelJob) {
+          eventually(timeout(10.seconds)) {
+            assert(SparkContextSuite.isTaskStarted)
+          }
+          sc.cancelJob(jobStart.jobId, REASON)
+          SparkContextSuite.cancelJob = false
+        }
+      }
+    }
+    sc.addSparkListener(listener)
+
+    for (cancelWhat <- Seq("stage", "job")) {
+      SparkContextSuite.isTaskStarted = false
+      SparkContextSuite.cancelStage = (cancelWhat == "stage")
+      SparkContextSuite.cancelJob = (cancelWhat == "job")
+
+      val ex = intercept[SparkException] {
+        sc.range(0, 10000L).mapPartitions { x =>
+          org.apache.spark.SparkContextSuite.isTaskStarted = true
+          x
+        }.cartesian(sc.range(0, 10L)).count()
+      }
+
+      ex.getCause() match {
+        case null =>
+          assert(ex.getMessage().contains(REASON))
+        case cause: SparkException =>
+          assert(cause.getMessage().contains(REASON))
+        case cause: Throwable =>
+          fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
+      }
+
+      eventually(timeout(20.seconds)) {
+        assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+      }
+    }
+  }
+
+}
+
+object SparkContextSuite {
+  @volatile var cancelJob = false
+  @volatile var cancelStage = false
+  @volatile var isTaskStarted = false
 }
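The test coordinates the listener with the running task through @volatile flags on the companion object rather than a fixed sleep (compare the Thread.sleep(100) removed from DataFrameRangeSuite below). Reduced to its essentials, the handshake looks like this sketch (names are illustrative, not the verbatim test code):

// Sketch of the flag-based handshake between task and listener.
object Flags {
  @volatile var isTaskStarted = false // written by the task, read by the listener
}
// Task side, inside mapPartitions:  Flags.isTaskStarted = true
// Listener side, before cancelling: eventually(timeout(10.seconds)) { assert(Flags.isTaskStarted) }
// Only then is cancelStage/cancelJob issued, so the cancellation (and its
// reason) is guaranteed to reach a computation that has actually started.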

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -329,7 +329,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
   /** Sends JobCancelled to the DAG scheduler. */
   private def cancel(jobId: Int) {
-    runEvent(JobCancelled(jobId))
+    runEvent(JobCancelled(jobId, None))
   }
 
   test("[SPARK-3353] parent stage should have lower stage id") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala

Lines changed: 15 additions & 4 deletions
@@ -31,6 +31,7 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 
 class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
+  import testImplicits._
 
   test("SPARK-7150 range api") {
     // numSlice is greater than length
@@ -137,25 +138,30 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
   test("Cancelling stage in a query with Range.") {
     val listener = new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        Thread.sleep(100)
+        eventually(timeout(10.seconds)) {
+          assert(DataFrameRangeSuite.isTaskStarted)
+        }
         sparkContext.cancelStage(taskStart.stageId)
       }
     }
 
     sparkContext.addSparkListener(listener)
     for (codegen <- Seq(true, false)) {
       withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+        DataFrameRangeSuite.isTaskStarted = false
         val ex = intercept[SparkException] {
-          spark.range(100000L).crossJoin(spark.range(100000L))
-            .toDF("a", "b").agg(sum("a"), sum("b")).collect()
+          spark.range(100000L).mapPartitions { x =>
+            DataFrameRangeSuite.isTaskStarted = true
+            x
+          }.crossJoin(spark.range(100L)).toDF("a", "b").agg(sum("a"), sum("b")).collect()
         }
         ex.getCause() match {
           case null =>
             assert(ex.getMessage().contains("cancelled"))
           case cause: SparkException =>
             assert(cause.getMessage().contains("cancelled"))
           case cause: Throwable =>
-            fail("Expected the casue to be SparkException, got " + cause.toString() + " instead.")
+            fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
         }
       }
       eventually(timeout(20.seconds)) {
@@ -164,3 +170,8 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
     }
   }
 }
+
+object DataFrameRangeSuite {
+  @volatile var isTaskStarted = false
+}
+
