@@ -17,16 +17,14 @@
 
 package org.apache.spark.sql
 
-import java.util.concurrent.{CountDownLatch, TimeUnit}
-
 import scala.concurrent.duration._
 import scala.math.abs
 import scala.util.Random
 
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
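For context on the two swapped-in imports: `TaskContext.get()` returns the context of the currently running task, so it is non-null only on a task thread, and `SparkListenerJobStart` fires on the driver's listener bus as soon as a job is submitted. A minimal standalone probe of the stage id, not part of this patch, assuming a local `SparkSession` (names here are hypothetical):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object StageIdProbe {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("stage-id-probe").getOrCreate()
    // TaskContext.get() is null on the driver; inside a task it exposes the
    // stage and partition the task belongs to.
    val stageIds = spark.sparkContext.parallelize(1 to 4, 2)
      .map(_ => TaskContext.get().stageId())
      .collect()
    println(s"tasks ran in stage(s): ${stageIds.distinct.mkString(", ")}")
    spark.stop()
  }
}
```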
@@ -154,53 +152,39 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
   }
 
   test("Cancelling stage in a query with Range.") {
-    // Save and restore the value because SparkContext is shared
-    val savedInterruptOnCancel = sparkContext
-      .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
-
-    try {
-      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
-
-      for (codegen <- Seq(true, false)) {
-        // This countdown latch used to make sure with all the stages cancelStage called in listener
-        val latch = new CountDownLatch(2)
-
-        val listener = new SparkListener {
-          override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-            sparkContext.cancelStage(taskStart.stageId)
-            latch.countDown()
-          }
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        eventually(timeout(10.seconds), interval(1.millis)) {
+          assert(DataFrameRangeSuite.stageToKill > 0)
         }
+        sparkContext.cancelStage(DataFrameRangeSuite.stageToKill)
+      }
+    }
 
-        sparkContext.addSparkListener(listener)
-        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
-          val ex = intercept[SparkException] {
-            sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x =>
-              x.synchronized {
-                x.wait()
-              }
-              x
-            }.toDF("id").agg(sum("id")).collect()
-          }
-          ex.getCause() match {
-            case null =>
-              assert(ex.getMessage().contains("cancelled"))
-            case cause: SparkException =>
-              assert(cause.getMessage().contains("cancelled"))
-            case cause: Throwable =>
-              fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
-          }
+    sparkContext.addSparkListener(listener)
+    for (codegen <- Seq(true, false)) {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+        DataFrameRangeSuite.stageToKill = -1
+        val ex = intercept[SparkException] {
+          spark.range(0, 100000000000L, 1, 1).map { x =>
+            DataFrameRangeSuite.stageToKill = TaskContext.get().stageId()
+            x
+          }.toDF("id").agg(sum("id")).collect()
         }
-        latch.await(20, TimeUnit.SECONDS)
-        eventually(timeout(20.seconds)) {
-          assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+        ex.getCause() match {
+          case null =>
+            assert(ex.getMessage().contains("cancelled"))
+          case cause: SparkException =>
+            assert(cause.getMessage().contains("cancelled"))
+          case cause: Throwable =>
+            fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
         }
-        sparkContext.removeSparkListener(listener)
       }
-    } finally {
-      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL,
-        savedInterruptOnCancel)
+      eventually(timeout(20.seconds)) {
+        assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+      }
     }
+    sparkContext.removeSparkListener(listener)
   }
 
   test("SPARK-20430 Initialize Range parameters in a driver side") {
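The reworked test above replaces the fixed `CountDownLatch` with ScalaTest's `Eventually`: the listener polls until a task has published its stage id, then cancels exactly that stage. A condensed sketch of the polling construct on its own, with hypothetical names, using ScalaTest's `SpanSugar` for the time literals:

```scala
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

object EventuallySketch {
  @volatile var ready = false

  def main(args: Array[String]): Unit = {
    // Another thread flips the flag after a short delay.
    new Thread(() => { Thread.sleep(50); ready = true }).start()

    // eventually re-runs the block every `interval` until it stops throwing,
    // failing only if `timeout` elapses first.
    eventually(timeout(10.seconds), interval(1.millis)) {
      assert(ready)
    }
    println("condition met")
  }
}
```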
@@ -220,3 +204,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
     }
   }
 }
+
+object DataFrameRangeSuite {
+  @volatile var stageToKill = -1
+}
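Why the flag lives in a top-level object rather than on the suite: a closure that wrote to a field of the test class would capture the whole non-serializable suite instance, whereas `DataFrameRangeSuite.stageToKill` is referenced statically and nothing extra is shipped with the task; `@volatile` makes the task thread's write visible to the listener-bus thread, which works here because local mode keeps everything in one JVM. A sketch of the distinction, with hypothetical names:

```scala
import org.apache.spark.sql.SparkSession

class SomeSuite {
  var fieldFlag = -1

  def capturesSuite(spark: SparkSession): Unit = {
    // fieldFlag is really this.fieldFlag, so the closure captures `this`;
    // Spark fails with "Task not serializable" unless SomeSuite is Serializable.
    spark.sparkContext.parallelize(1 to 2).foreach(_ => fieldFlag = 0)
  }

  def capturesNothing(spark: SparkSession): Unit = {
    // A top-level object is accessed statically, so the closure stays self-contained.
    spark.sparkContext.parallelize(1 to 2).foreach(_ => SomeSuiteState.flag = 0)
  }
}

object SomeSuiteState {
  // Visible across threads; in local mode the "executor" shares the driver JVM,
  // so a listener on the driver actually observes the task's write.
  @volatile var flag = -1
}
```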