Skip to content

Commit 5b8bb1b

Browse files
committed
[SPARK-9572] [STREAMING] [PYSPARK] Added StreamingContext.getActiveOrCreate() in Python
Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #8080 from tdas/SPARK-9572 and squashes the following commits: 64a231d [Tathagata Das] Fix based on comments 741a0d0 [Tathagata Das] Fixed style f4f094c [Tathagata Das] Tweaked test 9afcdbe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 e21488d [Tathagata Das] Minor update 1a371d9 [Tathagata Das] Addressed comments. 60479da [Tathagata Das] Fixed indent 9c2da9c [Tathagata Das] Fixed bugs b5bd32c [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 b55b348 [Tathagata Das] Removed prints 5781728 [Tathagata Das] Fix style issues b711214 [Tathagata Das] Reverted run-tests.py 643b59d [Tathagata Das] Revert unnecessary change 150e58c [Tathagata Das] Added StreamingContext.getActiveOrCreate() in Python
1 parent dbd778d commit 5b8bb1b

File tree

3 files changed

+177
-15
lines changed

3 files changed

+177
-15
lines changed

python/pyspark/streaming/context.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ class StreamingContext(object):
8686
"""
8787
_transformerSerializer = None
8888

89+
# Reference to a currently active StreamingContext
90+
_activeContext = None
91+
8992
def __init__(self, sparkContext, batchDuration=None, jssc=None):
9093
"""
9194
Create a new StreamingContext.
@@ -142,10 +145,10 @@ def getOrCreate(cls, checkpointPath, setupFunc):
142145
Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
143146
If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
144147
recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
145-
will be used to create a JavaStreamingContext.
148+
will be used to create a new context.
146149
147-
@param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program
148-
@param setupFunc: Function to create a new JavaStreamingContext and setup DStreams
150+
@param checkpointPath: Checkpoint directory used in an earlier streaming program
151+
@param setupFunc: Function to create a new context and setup DStreams
149152
"""
150153
# TODO: support checkpoint in HDFS
151154
if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
@@ -170,6 +173,52 @@ def getOrCreate(cls, checkpointPath, setupFunc):
170173
cls._transformerSerializer.ctx = sc
171174
return StreamingContext(sc, None, jssc)
172175

176+
@classmethod
def getActive(cls):
    """
    Return either the currently active StreamingContext (i.e., if there is a context started
    but not stopped) or None.

    Side effect: clears the cached `_activeContext` when the backing Java
    StreamingContext turns out to be no longer active.
    """
    activePythonContext = cls._activeContext
    if activePythonContext is not None:
        # Verify that the current running Java StreamingContext is active and is the same one
        # backing the supposedly active Python context
        activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode()
        activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive()

        if activeJvmContextOption.isEmpty():
            # JVM has no active context: the Python-side cache is stale.
            cls._activeContext = None
        elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId:
            cls._activeContext = None
            # Fixed typo in the message: "action" -> "active".
            raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext "
                            "backing the active Python StreamingContext. This is unexpected.")
    return cls._activeContext
197+
@classmethod
def getActiveOrCreate(cls, checkpointPath, setupFunc):
    """
    Either return the active StreamingContext (i.e. currently started but not stopped),
    or recreate a StreamingContext from checkpoint data or create a new StreamingContext
    using the provided setupFunc function. If the checkpointPath is None or does not contain
    valid checkpoint data, then setupFunc will be called to create a new context and setup
    DStreams.

    @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be
                           None if the intention is to always create a new context when there
                           is no active context.
    @param setupFunc:      Function to create a new context and setup DStreams
    """
    if setupFunc is None:
        raise Exception("setupFunc cannot be None")
    activeContext = cls.getActive()
    if activeContext is not None:
        # An already-started context always wins; checkpoint data is ignored.
        return activeContext
    elif checkpointPath is not None:
        # Recreate from checkpoint if present, otherwise getOrCreate falls
        # back to setupFunc itself.
        return cls.getOrCreate(checkpointPath, setupFunc)
    else:
        return setupFunc()
221+
173222
@property
174223
def sparkContext(self):
175224
"""
@@ -182,6 +231,7 @@ def start(self):
182231
Start the execution of the streams.
183232
"""
184233
self._jssc.start()
234+
StreamingContext._activeContext = self
185235

186236
def awaitTermination(self, timeout=None):
187237
"""
@@ -212,6 +262,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
212262
of all received data to be completed
213263
"""
214264
self._jssc.stop(stopSparkContext, stopGraceFully)
265+
StreamingContext._activeContext = None
215266
if stopSparkContext:
216267
self._sc.stop()
217268

python/pyspark/streaming/tests.py

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tempfile
2525
import random
2626
import struct
27+
import shutil
2728
from functools import reduce
2829

2930
if sys.version_info[:2] <= (2, 6):
@@ -59,12 +60,21 @@ def setUpClass(cls):
5960
@classmethod
def tearDownClass(cls):
    """Stop the shared SparkContext once the whole test class has finished."""
    cls.sc.stop()
    # Clean up in the JVM just in case there has been some issues in Python API
    jvm_sc_option = SparkContext._jvm.SparkContext.get()
    if jvm_sc_option.nonEmpty():
        jvm_sc_option.get().stop()
6267

6368
def setUp(self):
    """Give each test its own StreamingContext over the shared SparkContext."""
    self.ssc = StreamingContext(self.sc, self.duration)
6570

6671
def tearDown(self):
    """Stop the per-test StreamingContext, then clean any stray JVM context."""
    if self.ssc is not None:
        self.ssc.stop(False)
    # Clean up in the JVM just in case there has been some issues in Python API.
    # Bug fix: getActive() lives on the JVM StreamingContext object, not on
    # SparkContext (and only the returned StreamingContext has stop(boolean)).
    jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
    if jStreamingContextOption.nonEmpty():
        jStreamingContextOption.get().stop(False)
6878

6979
def wait_for(self, result, n):
7080
start_time = time.time()
@@ -442,6 +452,7 @@ def test_reduce_by_invalid_window(self):
442452
class StreamingContextTests(PySparkStreamingTestCase):
443453

444454
duration = 0.1
455+
setupCalled = False
445456

446457
def _add_input_stream(self):
447458
inputs = [range(1, x) for x in range(101)]
@@ -515,10 +526,85 @@ def func(rdds):
515526

516527
self.assertEqual([2, 3, 1], self._take(dstream, 3))
517528

529+
def test_get_active(self):
    """getActive() should track start/stop of the Python StreamingContext."""
    def attach_dummy_output(ssc):
        # An output operation is required before a context can be started.
        ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())

    self.assertEqual(StreamingContext.getActive(), None)

    # A started (and not yet stopped) context is reported as active.
    attach_dummy_output(self.ssc)
    self.ssc.start()
    self.assertEqual(StreamingContext.getActive(), self.ssc)

    # Stopping the context clears the active reference.
    self.ssc.stop(False)
    self.assertEqual(StreamingContext.getActive(), None)

    # If only the underlying Java context is stopped, getActive() detects the
    # mismatch and reports no active context.
    self.ssc = StreamingContext(self.sc, self.duration)
    attach_dummy_output(self.ssc)
    self.ssc.start()
    self.assertEqual(StreamingContext.getActive(), self.ssc)
    self.ssc._jssc.stop(False)
    self.assertEqual(StreamingContext.getActive(), None)
548+
549+
def test_get_active_or_create(self):
    """Test StreamingContext.getActiveOrCreate() without checkpoint data.

    See CheckpointTests for the checkpoint-based variants.
    """
    self.ssc = None
    self.assertEqual(StreamingContext.getActive(), None)

    def setupFunc():
        ssc = StreamingContext(self.sc, self.duration)
        ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        self.setupCalled = True
        return ssc

    # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active
    self.setupCalled = False
    self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
    self.assertTrue(self.setupCalled)

    # Verify that getActiveOrCreate() returns the active context and does not
    # call the setupFunc (fixed comment typo: "retuns")
    self.ssc.start()
    self.setupCalled = False
    self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
    self.assertFalse(self.setupCalled)

    # Verify that getActiveOrCreate() calls setupFunc after active context is stopped
    self.ssc.stop(False)
    self.setupCalled = False
    self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
    self.assertTrue(self.setupCalled)

    # Verify that getActiveOrCreate() calls setupFunc when only the Java
    # context has been stopped (i.e. there is no genuinely active context).
    # (Previous comment wrongly claimed this checked getActive() alone.)
    self.ssc = StreamingContext(self.sc, self.duration)
    self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
    self.ssc.start()
    self.assertEqual(StreamingContext.getActive(), self.ssc)
    self.ssc._jssc.stop(False)
    self.setupCalled = False
    self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
    self.assertTrue(self.setupCalled)
587+
518588

519589
class CheckpointTests(unittest.TestCase):
520590

521-
def test_get_or_create(self):
591+
setupCalled = False
592+
593+
@staticmethod
def tearDownClass():
    # Clean up in the JVM just in case there has been some issues in Python API.
    # Bug fix: getActive() is defined on the JVM StreamingContext object,
    # not on SparkContext.
    jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
    if jStreamingContextOption.nonEmpty():
        jStreamingContextOption.get().stop()
    jSparkContextOption = SparkContext._jvm.SparkContext.get()
    if jSparkContextOption.nonEmpty():
        jSparkContextOption.get().stop()
602+
603+
def tearDown(self):
    """Stop the context created by the test, releasing the SparkContext too.

    Uses getattr because CheckpointTests has no setUp(): if a test fails
    before assigning self.ssc, the attribute does not exist and the plain
    access would raise AttributeError, masking the real failure.
    """
    if getattr(self, "ssc", None) is not None:
        self.ssc.stop(True)
606+
607+
def test_get_or_create_and_get_active_or_create(self):
522608
inputd = tempfile.mkdtemp()
523609
outputd = tempfile.mkdtemp() + "/"
524610

@@ -533,11 +619,12 @@ def setup():
533619
wc = dstream.updateStateByKey(updater)
534620
wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
535621
wc.checkpoint(.5)
622+
self.setupCalled = True
536623
return ssc
537624

538625
cpd = tempfile.mkdtemp("test_streaming_cps")
539-
ssc = StreamingContext.getOrCreate(cpd, setup)
540-
ssc.start()
626+
self.ssc = StreamingContext.getOrCreate(cpd, setup)
627+
self.ssc.start()
541628

542629
def check_output(n):
543630
while not os.listdir(outputd):
@@ -552,7 +639,7 @@ def check_output(n):
552639
# not finished
553640
time.sleep(0.01)
554641
continue
555-
ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
642+
ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
556643
d = ordd.values().map(int).collect()
557644
if not d:
558645
time.sleep(0.01)
@@ -568,13 +655,37 @@ def check_output(n):
568655

569656
check_output(1)
570657
check_output(2)
571-
ssc.stop(True, True)
572658

659+
# Verify the getOrCreate() recovers from checkpoint files
660+
self.ssc.stop(True, True)
573661
time.sleep(1)
574-
ssc = StreamingContext.getOrCreate(cpd, setup)
575-
ssc.start()
662+
self.setupCalled = False
663+
self.ssc = StreamingContext.getOrCreate(cpd, setup)
664+
self.assertFalse(self.setupCalled)
665+
self.ssc.start()
576666
check_output(3)
577-
ssc.stop(True, True)
667+
668+
# Verify the getActiveOrCreate() recovers from checkpoint files
669+
self.ssc.stop(True, True)
670+
time.sleep(1)
671+
self.setupCalled = False
672+
self.ssc = StreamingContext.getActiveOrCreate(cpd, setup)
673+
self.assertFalse(self.setupCalled)
674+
self.ssc.start()
675+
check_output(4)
676+
677+
# Verify that getActiveOrCreate() returns active context
678+
self.setupCalled = False
679+
self.assertEquals(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc)
680+
self.assertFalse(self.setupCalled)
681+
682+
# Verify that getActiveOrCreate() calls setup() in absence of checkpoint files
683+
self.ssc.stop(True, True)
684+
shutil.rmtree(cpd) # delete checkpoint directory
685+
self.setupCalled = False
686+
self.ssc = StreamingContext.getActiveOrCreate(cpd, setup)
687+
self.assertTrue(self.setupCalled)
688+
self.ssc.stop(True, True)
578689

579690

580691
class KafkaStreamTests(PySparkStreamingTestCase):
@@ -1134,7 +1245,7 @@ def search_kinesis_asl_assembly_jar():
11341245
testcases.append(KinesisStreamTests)
11351246
elif are_kinesis_tests_enabled is False:
11361247
sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was "
1137-
"not compiled with -Pkinesis-asl profile. To run these tests, "
1248+
"not compiled into a JAR. To run these tests, "
11381249
"you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly "
11391250
"streaming-kinesis-asl-assembly/assembly' or "
11401251
"'build/mvn -Pkinesis-asl package' before running this test.")
@@ -1150,4 +1261,4 @@ def search_kinesis_asl_assembly_jar():
11501261
for testcase in testcases:
11511262
sys.stderr.write("[Running %s]\n" % (testcase))
11521263
tests = unittest.TestLoader().loadTestsFromTestCase(testcase)
1153-
unittest.TextTestRunner(verbosity=2).run(tests)
1264+
unittest.TextTestRunner(verbosity=3).run(tests)

python/run-tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def main():
158158
else:
159159
log_level = logging.INFO
160160
logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
161-
LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE)
161+
LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
162162
if os.path.exists(LOG_FILE):
163163
os.remove(LOG_FILE)
164164
python_execs = opts.python_executables.split(',')

0 commit comments

Comments
 (0)