@@ -3965,6 +3965,42 @@ def test_create_new_session_if_old_session_stopped(self):
39653965 finally :
39663966 newSession .stop ()
39673967
3968+ def test_active_session_with_None_and_not_None_context (self ):
3969+ from pyspark .context import SparkContext
3970+ from pyspark .conf import SparkConf
3971+ sc = SparkContext ._active_spark_context
3972+ self .assertEqual (sc , None )
3973+ activeSession = SparkSession .getActiveSession ()
3974+ self .assertEqual (activeSession , None )
3975+ sparkConf = SparkConf ()
3976+ sc = SparkContext .getOrCreate (sparkConf )
3977+ activeSession = sc ._jvm .SparkSession .getActiveSession ()
3978+ self .assertFalse (activeSession .isDefined ())
3979+ session = SparkSession (sc )
3980+ activeSession = sc ._jvm .SparkSession .getActiveSession ()
3981+ self .assertTrue (activeSession .isDefined ())
3982+ activeSession2 = SparkSession .getActiveSession ()
3983+ self .assertNotEqual (activeSession2 , None )
3984+
3985+
3986+ class SparkSessionTests3 (ReusedSQLTestCase ):
3987+
3988+ def test_get_active_session_after_create_dataframe (self ):
3989+ activeSession1 = SparkSession .getActiveSession ()
3990+ session1 = self .spark
3991+ self .assertEqual (session1 , activeSession1 )
3992+ session2 = self .spark .newSession ()
3993+ activeSession2 = SparkSession .getActiveSession ()
3994+ self .assertEqual (session1 , activeSession2 )
3995+ self .assertNotEqual (session2 , activeSession2 )
3996+ session2 .createDataFrame ([(1 , 'Alice' )], ['age' , 'name' ])
3997+ activeSession3 = SparkSession .getActiveSession ()
3998+ self .assertEqual (session2 , activeSession3 )
3999+ session1 .createDataFrame ([(1 , 'Alice' )], ['age' , 'name' ])
4000+ activeSession4 = SparkSession .getActiveSession ()
4001+ self .assertEqual (session1 , activeSession4 )
4002+ session2 .stop ()
4003+
39684004
39694005class UDFInitializationTests (unittest .TestCase ):
39704006 def tearDown (self ):
0 commit comments