@@ -3985,6 +3985,157 @@ def test_jvm_default_session_already_set(self):
39853985 spark .stop ()
39863986
39873987
3988+ class SparkSessionTests2 (unittest .TestCase ):
3989+
3990+ def test_active_session (self ):
3991+ spark = SparkSession .builder \
3992+ .master ("local" ) \
3993+ .getOrCreate ()
3994+ try :
3995+ activeSession = SparkSession .getActiveSession ()
3996+ df = activeSession .createDataFrame ([(1 , 'Alice' )], ['age' , 'name' ])
3997+ self .assertEqual (df .collect (), [Row (age = 1 , name = u'Alice' )])
3998+ finally :
3999+ spark .stop ()
4000+
4001+ def test_get_active_session_when_no_active_session (self ):
4002+ active = SparkSession .getActiveSession ()
4003+ self .assertEqual (active , None )
4004+ spark = SparkSession .builder \
4005+ .master ("local" ) \
4006+ .getOrCreate ()
4007+ active = SparkSession .getActiveSession ()
4008+ self .assertEqual (active , spark )
4009+ spark .stop ()
4010+ active = SparkSession .getActiveSession ()
4011+ self .assertEqual (active , None )
4012+
4013+ def test_SparkSession (self ):
4014+ spark = SparkSession .builder \
4015+ .master ("local" ) \
4016+ .config ("some-config" , "v2" ) \
4017+ .getOrCreate ()
4018+ try :
4019+ self .assertEqual (spark .conf .get ("some-config" ), "v2" )
4020+ self .assertEqual (spark .sparkContext ._conf .get ("some-config" ), "v2" )
4021+ self .assertEqual (spark .version , spark .sparkContext .version )
4022+ spark .sql ("CREATE DATABASE test_db" )
4023+ spark .catalog .setCurrentDatabase ("test_db" )
4024+ self .assertEqual (spark .catalog .currentDatabase (), "test_db" )
4025+ spark .sql ("CREATE TABLE table1 (name STRING, age INT) USING parquet" )
4026+ self .assertEqual (spark .table ("table1" ).columns , ['name' , 'age' ])
4027+ self .assertEqual (spark .range (3 ).count (), 3 )
4028+ finally :
4029+ spark .stop ()
4030+
4031+ def test_global_default_session (self ):
4032+ spark = SparkSession .builder \
4033+ .master ("local" ) \
4034+ .getOrCreate ()
4035+ try :
4036+ self .assertEqual (SparkSession .builder .getOrCreate (), spark )
4037+ finally :
4038+ spark .stop ()
4039+
4040+ def test_default_and_active_session (self ):
4041+ spark = SparkSession .builder \
4042+ .master ("local" ) \
4043+ .getOrCreate ()
4044+ activeSession = spark ._jvm .SparkSession .getActiveSession ()
4045+ defaultSession = spark ._jvm .SparkSession .getDefaultSession ()
4046+ try :
4047+ self .assertEqual (activeSession , defaultSession )
4048+ finally :
4049+ spark .stop ()
4050+
4051+ def test_config_option_propagated_to_existing_session (self ):
4052+ session1 = SparkSession .builder \
4053+ .master ("local" ) \
4054+ .config ("spark-config1" , "a" ) \
4055+ .getOrCreate ()
4056+ self .assertEqual (session1 .conf .get ("spark-config1" ), "a" )
4057+ session2 = SparkSession .builder \
4058+ .config ("spark-config1" , "b" ) \
4059+ .getOrCreate ()
4060+ try :
4061+ self .assertEqual (session1 , session2 )
4062+ self .assertEqual (session1 .conf .get ("spark-config1" ), "b" )
4063+ finally :
4064+ session1 .stop ()
4065+
4066+ def test_new_session (self ):
4067+ session = SparkSession .builder \
4068+ .master ("local" ) \
4069+ .getOrCreate ()
4070+ newSession = session .newSession ()
4071+ try :
4072+ self .assertNotEqual (session , newSession )
4073+ finally :
4074+ session .stop ()
4075+ newSession .stop ()
4076+
4077+ def test_create_new_session_if_old_session_stopped (self ):
4078+ session = SparkSession .builder \
4079+ .master ("local" ) \
4080+ .getOrCreate ()
4081+ session .stop ()
4082+ newSession = SparkSession .builder \
4083+ .master ("local" ) \
4084+ .getOrCreate ()
4085+ try :
4086+ self .assertNotEqual (session , newSession )
4087+ finally :
4088+ newSession .stop ()
4089+
4090+ def test_active_session_with_None_and_not_None_context (self ):
4091+ from pyspark .context import SparkContext
4092+ from pyspark .conf import SparkConf
4093+ sc = None
4094+ session = None
4095+ try :
4096+ sc = SparkContext ._active_spark_context
4097+ self .assertEqual (sc , None )
4098+ activeSession = SparkSession .getActiveSession ()
4099+ self .assertEqual (activeSession , None )
4100+ sparkConf = SparkConf ()
4101+ sc = SparkContext .getOrCreate (sparkConf )
4102+ activeSession = sc ._jvm .SparkSession .getActiveSession ()
4103+ self .assertFalse (activeSession .isDefined ())
4104+ session = SparkSession (sc )
4105+ activeSession = sc ._jvm .SparkSession .getActiveSession ()
4106+ self .assertTrue (activeSession .isDefined ())
4107+ activeSession2 = SparkSession .getActiveSession ()
4108+ self .assertNotEqual (activeSession2 , None )
4109+ finally :
4110+ if session is not None :
4111+ session .stop ()
4112+ if sc is not None :
4113+ sc .stop ()
4114+
4115+
4116+ class SparkSessionTests3 (ReusedSQLTestCase ):
4117+
4118+ def test_get_active_session_after_create_dataframe (self ):
4119+ session2 = None
4120+ try :
4121+ activeSession1 = SparkSession .getActiveSession ()
4122+ session1 = self .spark
4123+ self .assertEqual (session1 , activeSession1 )
4124+ session2 = self .spark .newSession ()
4125+ activeSession2 = SparkSession .getActiveSession ()
4126+ self .assertEqual (session1 , activeSession2 )
4127+ self .assertNotEqual (session2 , activeSession2 )
4128+ session2 .createDataFrame ([(1 , 'Alice' )], ['age' , 'name' ])
4129+ activeSession3 = SparkSession .getActiveSession ()
4130+ self .assertEqual (session2 , activeSession3 )
4131+ session1 .createDataFrame ([(1 , 'Alice' )], ['age' , 'name' ])
4132+ activeSession4 = SparkSession .getActiveSession ()
4133+ self .assertEqual (session1 , activeSession4 )
4134+ finally :
4135+ if session2 is not None :
4136+ session2 .stop ()
4137+
4138+
39884139class UDFInitializationTests (unittest .TestCase ):
39894140 def tearDown (self ):
39904141 if SparkSession ._instantiatedSession is not None :
0 commit comments