Skip to content

Commit ab11a3c

Browse files
huaxingao authored and Jackey Lee committed
[SPARK-25255][PYTHON] Add getActiveSession to SparkSession in PySpark
## What changes were proposed in this pull request?

Add `getActiveSession` to `SparkSession` in session.py.

## How was this patch tested?

Added a doctest.

Closes apache#22295 from huaxingao/spark25255.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Holden Karau <holden@pigscanfly.ca>
1 parent 44922c5 commit ab11a3c

File tree

2 files changed

+181
-0
lines changed

2 files changed

+181
-0
lines changed

python/pyspark/sql/session.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def getOrCreate(self):
192192
"""A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances"""
193193

194194
_instantiatedSession = None
195+
_activeSession = None
195196

196197
@ignore_unicode_prefix
197198
def __init__(self, sparkContext, jsparkSession=None):
@@ -233,7 +234,9 @@ def __init__(self, sparkContext, jsparkSession=None):
233234
if SparkSession._instantiatedSession is None \
234235
or SparkSession._instantiatedSession._sc._jsc is None:
235236
SparkSession._instantiatedSession = self
237+
SparkSession._activeSession = self
236238
self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
239+
self._jvm.SparkSession.setActiveSession(self._jsparkSession)
237240

238241
def _repr_html_(self):
239242
return """
@@ -255,6 +258,29 @@ def newSession(self):
255258
"""
256259
return self.__class__(self._sc, self._jsparkSession.newSession())
257260

261+
@classmethod
@since(3.0)
def getActiveSession(cls):
    """
    Returns the active SparkSession for the current thread, returned by the builder.

    Returns ``None`` when there is no active :class:`SparkContext`, or when
    the JVM reports that no active session is set.

    >>> s = SparkSession.getActiveSession()
    >>> l = [('Alice', 1)]
    >>> rdd = s.sparkContext.parallelize(l)
    >>> df = s.createDataFrame(rdd, ['name', 'age'])
    >>> df.select("age").collect()
    [Row(age=1)]
    """
    from pyspark import SparkContext
    sc = SparkContext._active_spark_context
    if sc is None:
        return None

    # Fetch the JVM-side Option once so that isDefined() and get() are
    # evaluated on the same object, and to avoid a second Py4J round trip.
    jvm_active = sc._jvm.SparkSession.getActiveSession()
    if not jvm_active.isDefined():
        return None

    # The wrapper built here is intentionally discarded: it is constructed
    # only for the class-level bookkeeping performed by
    # SparkSession.__init__, after which the tracked Python-side session
    # is returned.
    SparkSession(sc, jvm_active.get())
    return SparkSession._activeSession
283+
258284
@property
259285
@since(2.0)
260286
def sparkContext(self):
@@ -671,6 +697,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
671697
...
672698
Py4JJavaError: ...
673699
"""
700+
SparkSession._activeSession = self
701+
self._jvm.SparkSession.setActiveSession(self._jsparkSession)
674702
if isinstance(data, DataFrame):
675703
raise TypeError("data is already a DataFrame")
676704

@@ -826,7 +854,9 @@ def stop(self):
826854
self._sc.stop()
827855
# We should clean the default session up. See SPARK-23228.
828856
self._jvm.SparkSession.clearDefaultSession()
857+
self._jvm.SparkSession.clearActiveSession()
829858
SparkSession._instantiatedSession = None
859+
SparkSession._activeSession = None
830860

831861
@since(2.0)
832862
def __enter__(self):

python/pyspark/sql/tests.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3985,6 +3985,157 @@ def test_jvm_default_session_already_set(self):
39853985
spark.stop()
39863986

39873987

3988+
class SparkSessionTests2(unittest.TestCase):
    """Session lifecycle tests: default/active session tracking and builder reuse.

    Each test creates its own local session and is responsible for stopping
    it, so every fallible step that follows session creation is placed inside
    a ``try``/``finally`` to avoid leaking a live session into later tests.
    """

    def test_active_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            activeSession = SparkSession.getActiveSession()
            df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
        finally:
            spark.stop()

    def test_get_active_session_when_no_active_session(self):
        # Nothing has been created yet, so there must be no active session.
        self.assertIsNone(SparkSession.getActiveSession())
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            active = SparkSession.getActiveSession()
            self.assertEqual(active, spark)
        finally:
            # Stop even when the assertion above fails.
            spark.stop()
        # Stopping the session must clear the active session again.
        self.assertIsNone(SparkSession.getActiveSession())

    def test_SparkSession(self):
        spark = SparkSession.builder \
            .master("local") \
            .config("some-config", "v2") \
            .getOrCreate()
        try:
            self.assertEqual(spark.conf.get("some-config"), "v2")
            self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2")
            self.assertEqual(spark.version, spark.sparkContext.version)
            spark.sql("CREATE DATABASE test_db")
            spark.catalog.setCurrentDatabase("test_db")
            self.assertEqual(spark.catalog.currentDatabase(), "test_db")
            spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
            self.assertEqual(spark.table("table1").columns, ['name', 'age'])
            self.assertEqual(spark.range(3).count(), 3)
        finally:
            try:
                # Remove the database (and its table) created above so that a
                # re-run does not fail on "CREATE DATABASE test_db". Switch
                # away from test_db first so it is not the current database
                # while being dropped.
                spark.catalog.setCurrentDatabase("default")
                spark.sql("DROP DATABASE IF EXISTS test_db CASCADE")
            finally:
                spark.stop()

    def test_global_default_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            self.assertEqual(SparkSession.builder.getOrCreate(), spark)
        finally:
            spark.stop()

    def test_default_and_active_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            # Compare the JVM-side Option values directly.
            activeSession = spark._jvm.SparkSession.getActiveSession()
            defaultSession = spark._jvm.SparkSession.getDefaultSession()
            self.assertEqual(activeSession, defaultSession)
        finally:
            spark.stop()

    def test_config_option_propagated_to_existing_session(self):
        session1 = SparkSession.builder \
            .master("local") \
            .config("spark-config1", "a") \
            .getOrCreate()
        try:
            self.assertEqual(session1.conf.get("spark-config1"), "a")
            session2 = SparkSession.builder \
                .config("spark-config1", "b") \
                .getOrCreate()
            self.assertEqual(session1, session2)
            self.assertEqual(session1.conf.get("spark-config1"), "b")
        finally:
            session1.stop()

    def test_new_session(self):
        session = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        newSession = None
        try:
            newSession = session.newSession()
            self.assertNotEqual(session, newSession)
        finally:
            session.stop()
            if newSession is not None:
                newSession.stop()

    def test_create_new_session_if_old_session_stopped(self):
        session = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        session.stop()
        newSession = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            self.assertNotEqual(session, newSession)
        finally:
            newSession.stop()

    def test_active_session_with_None_and_not_None_context(self):
        from pyspark.context import SparkContext
        from pyspark.conf import SparkConf
        sc = None
        session = None
        try:
            # Without a SparkContext there can be no active session.
            sc = SparkContext._active_spark_context
            self.assertIsNone(sc)
            activeSession = SparkSession.getActiveSession()
            self.assertIsNone(activeSession)
            # A bare SparkContext does not create an active session in the JVM.
            sparkConf = SparkConf()
            sc = SparkContext.getOrCreate(sparkConf)
            activeSession = sc._jvm.SparkSession.getActiveSession()
            self.assertFalse(activeSession.isDefined())
            # Constructing a SparkSession is expected to set it.
            session = SparkSession(sc)
            activeSession = sc._jvm.SparkSession.getActiveSession()
            self.assertTrue(activeSession.isDefined())
            activeSession2 = SparkSession.getActiveSession()
            self.assertIsNotNone(activeSession2)
        finally:
            if session is not None:
                session.stop()
            if sc is not None:
                sc.stop()
4114+
4115+
4116+
class SparkSessionTests3(ReusedSQLTestCase):
    """Checks that createDataFrame switches which session is reported as active."""

    def test_get_active_session_after_create_dataframe(self):
        session2 = None
        try:
            # The fixture session starts out as the active one.
            active_at_start = SparkSession.getActiveSession()
            session1 = self.spark
            self.assertEqual(session1, active_at_start)

            # Merely creating a sibling session must not change the active one.
            session2 = self.spark.newSession()
            active_after_new = SparkSession.getActiveSession()
            self.assertEqual(session1, active_after_new)
            self.assertNotEqual(session2, active_after_new)

            # Creating a DataFrame on a session makes that session active.
            session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(session2, SparkSession.getActiveSession())

            # ...and doing the same on the first session switches it back.
            session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(session1, SparkSession.getActiveSession())
        finally:
            # NOTE(review): session2 comes from self.spark.newSession(), and
            # stop() stops the session's SparkContext, so this presumably
            # also stops the context used by the shared fixture - confirm
            # this is intended for a ReusedSQLTestCase.
            if session2 is not None:
                session2.stop()
4137+
4138+
39884139
class UDFInitializationTests(unittest.TestCase):
39894140
def tearDown(self):
39904141
if SparkSession._instantiatedSession is not None:

0 commit comments

Comments
 (0)