2424import tempfile
2525import random
2626import struct
27+ import shutil
2728from functools import reduce
2829
2930if sys .version_info [:2 ] <= (2 , 6 ):
@@ -59,12 +60,21 @@ def setUpClass(cls):
@classmethod
def tearDownClass(cls):
    """Stop the class-level SparkContext, then stop whatever the JVM still holds.

    The second step is a safety net: it runs after the Python-side stop and
    only acts when the JVM-side lookup reports a non-empty result.
    """
    cls.sc.stop()

    # Clean up in the JVM in case something went wrong in the Python API.
    leftover = SparkContext._jvm.SparkContext.get()
    if not leftover.nonEmpty():
        return
    leftover.get().stop()
6267
def setUp(self):
    """Build a new StreamingContext from ``self.sc`` and ``self.duration`` for each test."""
    fresh_context = StreamingContext(self.sc, self.duration)
    self.ssc = fresh_context
6570
def tearDown(self):
    """Stop the per-test StreamingContext while keeping the SparkContext alive.

    A test may set ``self.ssc`` to ``None`` to signal that there is no
    Python-side context to stop; the JVM-side cleanup below still runs.
    """
    if self.ssc is not None:
        # False: leave the SparkContext running; tearDownClass stops it.
        self.ssc.stop(False)

    # Clean up in the JVM in case something went wrong in the Python API.
    # The active streaming context is looked up on the JVM StreamingContext
    # object (not SparkContext), matching the stop(False) call made on it.
    jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
    if jStreamingContextOption.nonEmpty():
        jStreamingContextOption.get().stop(False)
6878
6979 def wait_for (self , result , n ):
7080 start_time = time .time ()
@@ -442,6 +452,7 @@ def test_reduce_by_invalid_window(self):
442452class StreamingContextTests (PySparkStreamingTestCase ):
443453
444454 duration = 0.1
455+ setupCalled = False
445456
446457 def _add_input_stream (self ):
447458 inputs = [range (1 , x ) for x in range (101 )]
@@ -515,10 +526,85 @@ def func(rdds):
515526
516527 self .assertEqual ([2 , 3 , 1 ], self ._take (dstream , 3 ))
517528
def test_get_active(self):
    """Check StreamingContext.getActive() before start, after start and after stop."""

    def start_with_counting_output(ssc):
        # Register a trivial output (count each RDD) and start the context.
        ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        ssc.start()

    # Nothing has been started yet, so no context is active.
    self.assertEqual(StreamingContext.getActive(), None)

    # Once started, the context is reported as the active one.
    start_with_counting_output(self.ssc)
    self.assertEqual(StreamingContext.getActive(), self.ssc)

    # Stopping through the Python API clears the active context.
    self.ssc.stop(False)
    self.assertEqual(StreamingContext.getActive(), None)

    # Stopping only the underlying Java context must clear it as well.
    self.ssc = StreamingContext(self.sc, self.duration)
    start_with_counting_output(self.ssc)
    self.assertEqual(StreamingContext.getActive(), self.ssc)
    self.ssc._jssc.stop(False)
    self.assertEqual(StreamingContext.getActive(), None)
548+
def test_get_active_or_create(self):
    """Exercise StreamingContext.getActiveOrCreate() without checkpoint data.

    The checkpoint-based variants are covered in CheckpointTests.
    """
    self.ssc = None
    self.assertEqual(StreamingContext.getActive(), None)

    def build_context():
        # Setup function handed to getActiveOrCreate(); records that it ran.
        new_ssc = StreamingContext(self.sc, self.duration)
        new_ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        self.setupCalled = True
        return new_ssc

    def get_active_or_create():
        # Reset the flag first so each check observes only this call.
        self.setupCalled = False
        return StreamingContext.getActiveOrCreate(None, build_context)

    # With no active context, the setup function must be called.
    self.ssc = get_active_or_create()
    self.assertTrue(self.setupCalled)

    # With an active context, it is returned and the setup function is skipped.
    self.ssc.start()
    self.assertEqual(get_active_or_create(), self.ssc)
    self.assertFalse(self.setupCalled)

    # After the active context is stopped, the setup function is called again.
    self.ssc.stop(False)
    self.ssc = get_active_or_create()
    self.assertTrue(self.setupCalled)

    # If only the Java context is stopped, the setup function is called too.
    self.ssc = StreamingContext(self.sc, self.duration)
    self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
    self.ssc.start()
    self.assertEqual(StreamingContext.getActive(), self.ssc)
    self.ssc._jssc.stop(False)
    self.ssc = get_active_or_create()
    self.assertTrue(self.setupCalled)
587+
518588
519589class CheckpointTests (unittest .TestCase ):
520590
521- def test_get_or_create (self ):
591+ setupCalled = False
592+
@staticmethod
def tearDownClass():
    """Stop any streaming context and Spark context still alive in the JVM.

    This is a safety net that runs once after all tests in the class, in
    case something went wrong in the Python API and a context was left
    behind.
    """
    # The active streaming context is looked up on the JVM StreamingContext
    # object (not SparkContext).
    jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
    if jStreamingContextOption.nonEmpty():
        # Pass the flag explicitly, like every other streaming stop() call
        # in this file; True asks for the SparkContext to be stopped too.
        jStreamingContextOption.get().stop(True)

    jSparkContextOption = SparkContext._jvm.SparkContext.get()
    if jSparkContextOption.nonEmpty():
        jSparkContextOption.get().stop()
602+
def tearDown(self):
    """Stop the StreamingContext a test left behind, if it created one.

    In the visible code ``self.ssc`` is assigned only inside the test body,
    so the attribute is missing when a test fails before creating its
    context. A missing attribute is treated like ``None`` so that tearDown
    does not raise AttributeError and hide the original failure.
    """
    ssc = getattr(self, "ssc", None)
    if ssc is not None:
        ssc.stop(True)
606+
607+ def test_get_or_create_and_get_active_or_create (self ):
522608 inputd = tempfile .mkdtemp ()
523609 outputd = tempfile .mkdtemp () + "/"
524610
@@ -533,11 +619,12 @@ def setup():
533619 wc = dstream .updateStateByKey (updater )
534620 wc .map (lambda x : "%s,%d" % x ).saveAsTextFiles (outputd + "test" )
535621 wc .checkpoint (.5 )
622+ self .setupCalled = True
536623 return ssc
537624
538625 cpd = tempfile .mkdtemp ("test_streaming_cps" )
539- ssc = StreamingContext .getOrCreate (cpd , setup )
540- ssc .start ()
626+ self . ssc = StreamingContext .getOrCreate (cpd , setup )
627+ self . ssc .start ()
541628
542629 def check_output (n ):
543630 while not os .listdir (outputd ):
@@ -552,7 +639,7 @@ def check_output(n):
552639 # not finished
553640 time .sleep (0.01 )
554641 continue
555- ordd = ssc .sparkContext .textFile (p ).map (lambda line : line .split ("," ))
642+ ordd = self . ssc .sparkContext .textFile (p ).map (lambda line : line .split ("," ))
556643 d = ordd .values ().map (int ).collect ()
557644 if not d :
558645 time .sleep (0.01 )
@@ -568,13 +655,37 @@ def check_output(n):
568655
569656 check_output (1 )
570657 check_output (2 )
571- ssc .stop (True , True )
572658
659+ # Verify the getOrCreate() recovers from checkpoint files
660+ self .ssc .stop (True , True )
573661 time .sleep (1 )
574- ssc = StreamingContext .getOrCreate (cpd , setup )
575- ssc .start ()
662+ self .setupCalled = False
663+ self .ssc = StreamingContext .getOrCreate (cpd , setup )
664+ self .assertFalse (self .setupCalled )
665+ self .ssc .start ()
576666 check_output (3 )
577- ssc .stop (True , True )
667+
668+ # Verify the getActiveOrCreate() recovers from checkpoint files
669+ self .ssc .stop (True , True )
670+ time .sleep (1 )
671+ self .setupCalled = False
672+ self .ssc = StreamingContext .getActiveOrCreate (cpd , setup )
673+ self .assertFalse (self .setupCalled )
674+ self .ssc .start ()
675+ check_output (4 )
676+
677+ # Verify that getActiveOrCreate() returns active context
678+ self .setupCalled = False
679+ self .assertEquals (StreamingContext .getActiveOrCreate (cpd , setup ), self .ssc )
680+ self .assertFalse (self .setupCalled )
681+
682+ # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files
683+ self .ssc .stop (True , True )
684+ shutil .rmtree (cpd ) # delete checkpoint directory
685+ self .setupCalled = False
686+ self .ssc = StreamingContext .getActiveOrCreate (cpd , setup )
687+ self .assertTrue (self .setupCalled )
688+ self .ssc .stop (True , True )
578689
579690
580691class KafkaStreamTests (PySparkStreamingTestCase ):
@@ -1134,7 +1245,7 @@ def search_kinesis_asl_assembly_jar():
11341245 testcases .append (KinesisStreamTests )
11351246 elif are_kinesis_tests_enabled is False :
11361247 sys .stderr .write ("Skipping all Kinesis Python tests as the optional Kinesis project was "
1137- "not compiled with -Pkinesis-asl profile . To run these tests, "
1248+ "not compiled into a JAR . To run these tests, "
11381249 "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly "
11391250 "streaming-kinesis-asl-assembly/assembly' or "
11401251 "'build/mvn -Pkinesis-asl package' before running this test." )
@@ -1150,4 +1261,4 @@ def search_kinesis_asl_assembly_jar():
11501261 for testcase in testcases :
11511262 sys .stderr .write ("[Running %s]\n " % (testcase ))
11521263 tests = unittest .TestLoader ().loadTestsFromTestCase (testcase )
1153- unittest .TextTestRunner (verbosity = 2 ).run (tests )
1264+ unittest .TextTestRunner (verbosity = 3 ).run (tests )
0 commit comments