57 changes: 54 additions & 3 deletions python/pyspark/streaming/context.py
@@ -86,6 +86,9 @@ class StreamingContext(object):
"""
_transformerSerializer = None

# Reference to a currently active StreamingContext
_activeContext = None

def __init__(self, sparkContext, batchDuration=None, jssc=None):
"""
Create a new StreamingContext.
@@ -142,10 +145,10 @@ def getOrCreate(cls, checkpointPath, setupFunc):
Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
will be used to create a JavaStreamingContext.
will be used to create a new context.

@param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program
@param setupFunc: Function to create a new JavaStreamingContext and setup DStreams
@param checkpointPath: Checkpoint directory used in an earlier streaming program
@param setupFunc: Function to create a new context and setup DStreams
"""
# TODO: support checkpoint in HDFS
if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
@@ -170,6 +173,52 @@ def getOrCreate(cls, checkpointPath, setupFunc):
cls._transformerSerializer.ctx = sc
return StreamingContext(sc, None, jssc)
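# Illustrative usage sketch (not part of this diff): recover a context from a
# checkpoint directory if one exists, otherwise build a fresh one via setupFunc.
# The names below ("checkpoint_dir", "input_dir", the 1-second batch interval,
# and an existing SparkContext `sc`) are placeholders, not values from this change.
#
#     def create_context():
#         ssc = StreamingContext(sc, 1)
#         ssc.textFileStream("input_dir").pprint()
#         ssc.checkpoint("checkpoint_dir")
#         return ssc
#
#     ssc = StreamingContext.getOrCreate("checkpoint_dir", create_context)
#     ssc.start()
#     ssc.awaitTermination()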

@classmethod
def getActive(cls):
"""
Return the currently active StreamingContext (i.e., a context that has been started
but not stopped), or None if no such context exists.
"""
activePythonContext = cls._activeContext
if activePythonContext is not None:
# Verify that the current running Java StreamingContext is active and is the same one
# backing the supposedly active Python context
activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode()
activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive()

if activeJvmContextOption.isEmpty():
cls._activeContext = None
elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId:
cls._activeContext = None
raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext "
"backing the action Python StreamingContext. This is unexpected.")
return cls._activeContext
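# Illustrative sketch of the expected getActive() behavior (assumes an existing
# SparkContext `sc`; the 1-second batch interval is a placeholder):
#
#     ssc = StreamingContext(sc, 1)
#     ssc.queueStream([[1, 2, 3]]).foreachRDD(lambda rdd: rdd.count())
#     ssc.start()
#     assert StreamingContext.getActive() is ssc   # started, so it is active
#     ssc.stop(stopSparkContext=False)
#     assert StreamingContext.getActive() is None  # stopped, so nothing is active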

@classmethod
def getActiveOrCreate(cls, checkpointPath, setupFunc):
Member: Maybe it's better to exchange the positions of checkpointPath and setupFunc, so that you can call this with checkpointPath=None.

Contributor Author: The problem is that the parameter order would then be the opposite of getOrCreate's, which in my experience creates a lot of confusion when switching between getOrCreate and getActiveOrCreate. That's why I explicitly kept the ordering the same: it is more annoying for developers to get the parameter order wrong (more so in Python than in Java/Scala) than to explicitly provide a checkpoint path or None. What do you think?

Member: Good point. Let's keep the current order.

"""
Either return the active StreamingContext (i.e. one that has been started but not
stopped), recreate a StreamingContext from checkpoint data, or create a new
StreamingContext using the provided setupFunc. If checkpointPath is None or does not
contain valid checkpoint data, then setupFunc will be called to create a new context
and set up the DStreams.

@param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be
None if the intention is to always create a new context when there
is no active context.
@param setupFunc: Function to create a new context and setup DStreams
"""

if setupFunc is None:
raise Exception("setupFunc cannot be None")
activeContext = cls.getActive()
if activeContext is not None:
return activeContext
elif checkpointPath is not None:
return cls.getOrCreate(checkpointPath, setupFunc)
else:
return setupFunc()
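# Illustrative usage sketch ("checkpoint_dir" is a placeholder path and `sc` an
# existing SparkContext). Note that the (checkpointPath, setupFunc) ordering
# deliberately mirrors getOrCreate(), per the review discussion above.
#
#     def create_context():
#         ssc = StreamingContext(sc, 1)
#         ssc.queueStream([[1, 2, 3]]).foreachRDD(lambda rdd: rdd.count())
#         ssc.checkpoint("checkpoint_dir")
#         return ssc
#
#     # Reuses the running context if there is one; otherwise recovers from
#     # "checkpoint_dir" or, failing that, calls create_context().
#     ssc = StreamingContext.getActiveOrCreate("checkpoint_dir", create_context)
#     ssc.start()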

@property
def sparkContext(self):
"""
@@ -182,6 +231,7 @@ def start(self):
Start the execution of the streams.
"""
self._jssc.start()
StreamingContext._activeContext = self

def awaitTermination(self, timeout=None):
"""
@@ -212,6 +262,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
of all received data to be completed
"""
self._jssc.stop(stopSparkContext, stopGraceFully)
StreamingContext._activeContext = None
if stopSparkContext:
self._sc.stop()

133 changes: 122 additions & 11 deletions python/pyspark/streaming/tests.py
@@ -24,6 +24,7 @@
import tempfile
import random
import struct
import shutil
from functools import reduce

if sys.version_info[:2] <= (2, 6):
@@ -59,12 +60,21 @@ def setUpClass(cls):
@classmethod
def tearDownClass(cls):
cls.sc.stop()
# Clean up in the JVM just in case there have been some issues in the Python API
jSparkContextOption = SparkContext._jvm.SparkContext.get()
if jSparkContextOption.nonEmpty():
jSparkContextOption.get().stop()

def setUp(self):
self.ssc = StreamingContext(self.sc, self.duration)

def tearDown(self):
self.ssc.stop(False)
if self.ssc is not None:
self.ssc.stop(False)
# Clean up in the JVM just in case there have been some issues in the Python API
jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
if jStreamingContextOption.nonEmpty():
jStreamingContextOption.get().stop(False)

def wait_for(self, result, n):
start_time = time.time()
@@ -442,6 +452,7 @@ def test_reduce_by_invalid_window(self):
class StreamingContextTests(PySparkStreamingTestCase):

duration = 0.1
setupCalled = False

def _add_input_stream(self):
inputs = [range(1, x) for x in range(101)]
@@ -515,10 +526,85 @@ def func(rdds):

self.assertEqual([2, 3, 1], self._take(dstream, 3))

def test_get_active(self):
self.assertEqual(StreamingContext.getActive(), None)

# Verify that getActive() returns the active context
self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
self.ssc.start()
self.assertEqual(StreamingContext.getActive(), self.ssc)

# Verify that getActive() returns None
self.ssc.stop(False)
self.assertEqual(StreamingContext.getActive(), None)

# Verify that if the Java context is stopped, then getActive() returns None
self.ssc = StreamingContext(self.sc, self.duration)
self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
self.ssc.start()
self.assertEqual(StreamingContext.getActive(), self.ssc)
self.ssc._jssc.stop(False)
self.assertEqual(StreamingContext.getActive(), None)

def test_get_active_or_create(self):
# Test StreamingContext.getActiveOrCreate() without checkpoint data
# See CheckpointTests for tests with checkpoint data
self.ssc = None
self.assertEqual(StreamingContext.getActive(), None)

def setupFunc():
ssc = StreamingContext(self.sc, self.duration)
ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
self.setupCalled = True
return ssc

# Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active
self.setupCalled = False
self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
self.assertTrue(self.setupCalled)

# Verify that getActiveOrCreate() returns the active context and does not call the setupFunc
self.ssc.start()
self.setupCalled = False
self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
self.assertFalse(self.setupCalled)

# Verify that getActiveOrCreate() calls setupFunc after active context is stopped
self.ssc.stop(False)
self.setupCalled = False
self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
self.assertTrue(self.setupCalled)

# Verify that getActiveOrCreate() calls setupFunc if the Java context is stopped
self.ssc = StreamingContext(self.sc, self.duration)
self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
self.ssc.start()
self.assertEqual(StreamingContext.getActive(), self.ssc)
self.ssc._jssc.stop(False)
self.setupCalled = False
self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
self.assertTrue(self.setupCalled)


class CheckpointTests(unittest.TestCase):

def test_get_or_create(self):
setupCalled = False

@staticmethod
def tearDownClass():
# Clean up in the JVM just in case there have been some issues in the Python API
jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
if jStreamingContextOption.nonEmpty():
jStreamingContextOption.get().stop()
jSparkContextOption = SparkContext._jvm.SparkContext.get()
if jSparkContextOption.nonEmpty():
jSparkContextOption.get().stop()

def tearDown(self):
if self.ssc is not None:
self.ssc.stop(True)

def test_get_or_create_and_get_active_or_create(self):
inputd = tempfile.mkdtemp()
outputd = tempfile.mkdtemp() + "/"

@@ -533,11 +619,12 @@ def setup():
wc = dstream.updateStateByKey(updater)
wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
wc.checkpoint(.5)
self.setupCalled = True
return ssc

cpd = tempfile.mkdtemp("test_streaming_cps")
ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
self.ssc = StreamingContext.getOrCreate(cpd, setup)
self.ssc.start()

def check_output(n):
while not os.listdir(outputd):
@@ -552,7 +639,7 @@ def check_output(n):
# not finished
time.sleep(0.01)
continue
ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
d = ordd.values().map(int).collect()
if not d:
time.sleep(0.01)
@@ -568,13 +655,37 @@

check_output(1)
check_output(2)
ssc.stop(True, True)

# Verify the getOrCreate() recovers from checkpoint files
self.ssc.stop(True, True)
time.sleep(1)
ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
self.setupCalled = False
self.ssc = StreamingContext.getOrCreate(cpd, setup)
self.assertFalse(self.setupCalled)
self.ssc.start()
check_output(3)
ssc.stop(True, True)

# Verify the getActiveOrCreate() recovers from checkpoint files
self.ssc.stop(True, True)
time.sleep(1)
self.setupCalled = False
self.ssc = StreamingContext.getActiveOrCreate(cpd, setup)
self.assertFalse(self.setupCalled)
self.ssc.start()
check_output(4)

# Verify that getActiveOrCreate() returns active context
self.setupCalled = False
self.assertEqual(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc)
self.assertFalse(self.setupCalled)

# Verify that getActiveOrCreate() calls setup() in absence of checkpoint files
self.ssc.stop(True, True)
shutil.rmtree(cpd) # delete checkpoint directory
self.setupCalled = False
self.ssc = StreamingContext.getActiveOrCreate(cpd, setup)
self.assertTrue(self.setupCalled)
self.ssc.stop(True, True)


class KafkaStreamTests(PySparkStreamingTestCase):
@@ -1134,7 +1245,7 @@ def search_kinesis_asl_assembly_jar():
testcases.append(KinesisStreamTests)
elif are_kinesis_tests_enabled is False:
sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was "
"not compiled with -Pkinesis-asl profile. To run these tests, "
"not compiled into a JAR. To run these tests, "
"you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly "
"streaming-kinesis-asl-assembly/assembly' or "
"'build/mvn -Pkinesis-asl package' before running this test.")
@@ -1150,4 +1261,4 @@
for testcase in testcases:
sys.stderr.write("[Running %s]\n" % (testcase))
tests = unittest.TestLoader().loadTestsFromTestCase(testcase)
unittest.TextTestRunner(verbosity=2).run(tests)
unittest.TextTestRunner(verbosity=3).run(tests)
2 changes: 1 addition & 1 deletion python/run-tests.py
@@ -158,7 +158,7 @@ def main():
else:
log_level = logging.INFO
logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE)
LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
Member: I think the previous code means that the stdout is in python/unit-tests.log, which is correct.

Contributor Author: If you look at how LOG_FILE is defined, it is already an absolute path.
https://github.com/tdas/spark/blob/SPARK-9572/python/run-tests.py#L69

Member: Oh, right. Good catch about the path.
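For reference, a hedged sketch of how an already-absolute LOG_FILE could be built; the real definition is at the linked line above and is not reproduced here, so the repo-root computation below is an assumption, not copied from run-tests.py:

    import os
    SPARK_HOME = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # assumed repo root
    LOG_FILE = os.path.join(SPARK_HOME, "python", "unit-tests.log")  # already absolute, so no "python/" prefix is needed in the message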

if os.path.exists(LOG_FILE):
os.remove(LOG_FILE)
python_execs = opts.python_executables.split(',')