@@ -469,13 +469,18 @@ def setUp(self):
         self.batachDuration = Milliseconds(500)
         self.sparkHome = "SomeDir"
         self.envPair = {"key": "value"}
+        self.ssc = None
+        self.sc = None

     def tearDown(self):
        # Do not call pyspark.streaming.context.StreamingContext.stop directly because
        # we do not wait to shutdown py4j client.
        # We need to change this to simply call StreamingContext.stop().
-        self.ssc._jssc.stop()
-        self.ssc._sc.stop()
+        # self.ssc._jssc.stop()
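+        # Stop only the contexts this test actually created; setUp() initializes both to None.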
+        if self.ssc is not None:
+            self.ssc.stop()
+        if self.sc is not None:
+            self.sc.stop()
        # Why does it take so long to terminate the StreamingContext and SparkContext?
        # Should we change the sleep time if this depends on machine spec?
        time.sleep(1)
@@ -486,48 +491,67 @@ def tearDownClass(cls):
         SparkContext._gateway._shutdown_callback_server()

     def test_from_no_conf_constructor(self):
-        ssc = StreamingContext(master=self.master, appName=self.appName, duration=batachDuration)
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batachDuration)
        # An alternative way to get the master: ssc.sparkContext.master
        # I try to keep the code close to the Scala API.
-        self.assertEqual(ssc.sparkContext._conf.get("spark.master"), self.master)
-        self.assertEqual(ssc.sparkContext._conf.get("spark.app.name"), self.appName)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.master"), self.master)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.app.name"), self.appName)

     def test_from_no_conf_plus_spark_home(self):
-        ssc = StreamingContext(master=self.master, appName=self.appName,
-                               sparkHome=self.sparkHome, duration=batachDuration)
-        self.assertEqual(ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, duration=self.batachDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+
+    def test_from_no_conf_plus_spark_home_plus_env(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, environment=self.envPair,
+                                    duration=self.batachDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.executorEnv.key"), self.envPair["key"])

     def test_from_existing_spark_context(self):
-        sc = SparkContext(master=self.master, appName=self.appName)
-        ssc = StreamingContext(sparkContext=sc)
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)

     def test_existing_spark_context_with_settings(self):
         conf = SparkConf()
         conf.set("spark.cleaner.ttl", "10")
-        sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
-        ssc = StreamingContext(context=sc)
-        self.assertEqual(int(ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
-
-    def _addInputStream(self, s):
-        test_inputs = map(lambda x: range(1, x), range(5, 101))
-        # make sure numSlice is 2 due to deserializer proglem in pyspark
-        s._testInputStream(test_inputs, 2)
-
-    def test_from_no_conf_plus_spark_home_plus_env(self):
-        pass
+        self.sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

     def test_from_conf_with_settings(self):
-        pass
+        conf = SparkConf()
+        conf.set("spark.cleaner.ttl", "10")
+        conf.setMaster(self.master)
+        conf.setAppName(self.appName)
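+        # Build the StreamingContext from a SparkConf alone, without explicit master/appName arguments.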
+        self.ssc = StreamingContext(conf=conf, duration=self.batachDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

     def test_stop_only_streaming_context(self):
-        pass
-
-    def test_await_termination(self):
-        pass
-
-
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop(False)
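+        # After stopping only the StreamingContext, the underlying SparkContext must still work.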
+        self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

+    def test_stop_multiple_times(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batachDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop()
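+        # Calling stop() again on an already-stopped context should be a no-op, not an error.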
+        self.ssc.stop()

+    def _addInputStream(self, s):
+        # Make sure the length of each input is over 3 and numSlice is 2,
+        # due to a deserializer problem in pyspark.streaming.
+        test_inputs = map(lambda x: range(1, x), range(5, 101))
+        test_stream = s._testInputStream(test_inputs, 2)
+        # Register a fake output operation.
+        result = list()
+        test_stream._test_output(result)

 if __name__ == "__main__":
     unittest.main()