Skip to content

Commit 5b4e37c

Browse files
author
Matthew Farrellee
committed
[SPARK-3458] enable python "with" statements for SparkContext
allow for best practice code, try: sc = SparkContext() app(sc) finally: sc.stop() to be written using a "with" statement, with SparkContext() as sc: app(sc)
1 parent c419e4f commit 5b4e37c

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

python/pyspark/context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,20 @@ def _ensure_initialized(cls, instance=None, gateway=None):
232232
else:
233233
SparkContext._active_spark_context = instance
234234

235+
def __enter__(self):
236+
"""
237+
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
238+
"""
239+
return self
240+
241+
def __exit__(self, type, value, trace):
242+
"""
243+
Enable 'with SparkContext(...) as sc: app' syntax.
244+
245+
Specifically stop the context on exit of the with block.
246+
"""
247+
self.stop()
248+
235249
@classmethod
236250
def setSystemProperty(cls, key, value):
237251
"""

python/pyspark/tests.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,35 @@ def test_single_script_on_cluster(self):
12541254
self.assertIn("[2, 4, 6]", out)
12551255

12561256

1257+
class ContextStopTests(unittest.TestCase):
1258+
1259+
def test_stop(self):
1260+
sc = SparkContext()
1261+
self.assertNotEqual(SparkContext._active_spark_context, None)
1262+
sc.stop()
1263+
self.assertEqual(SparkContext._active_spark_context, None)
1264+
1265+
def test_with(self):
1266+
with SparkContext() as sc:
1267+
self.assertNotEqual(SparkContext._active_spark_context, None)
1268+
self.assertEqual(SparkContext._active_spark_context, None)
1269+
1270+
def test_with_exception(self):
1271+
try:
1272+
with SparkContext() as sc:
1273+
self.assertNotEqual(SparkContext._active_spark_context, None)
1274+
raise Exception()
1275+
except:
1276+
pass
1277+
self.assertEqual(SparkContext._active_spark_context, None)
1278+
1279+
def test_with_stop(self):
1280+
with SparkContext() as sc:
1281+
self.assertNotEqual(SparkContext._active_spark_context, None)
1282+
sc.stop()
1283+
self.assertEqual(SparkContext._active_spark_context, None)
1284+
1285+
12571286
@unittest.skipIf(not _have_scipy, "SciPy not installed")
12581287
class SciPyTests(PySparkTestCase):
12591288

0 commit comments

Comments
 (0)