30 changes: 30 additions & 0 deletions python/pyspark/sql/session.py
@@ -192,6 +192,7 @@ def getOrCreate(self):
"""A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances"""

_instantiatedSession = None
_activeSession = None

    @ignore_unicode_prefix
    def __init__(self, sparkContext, jsparkSession=None):
@@ -233,7 +234,9 @@ def __init__(self, sparkContext, jsparkSession=None):
        if SparkSession._instantiatedSession is None \
                or SparkSession._instantiatedSession._sc._jsc is None:
            SparkSession._instantiatedSession = self
            SparkSession._activeSession = self
            self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
            self._jvm.SparkSession.setActiveSession(self._jsparkSession)
Contributor:

Thanks for catching this! Filed a follow-up: https://issues.apache.org/jira/browse/SPARK-25432

Contributor Author:

Thank you very much for your comments.
I have a question here: in the stop() method, shall we clear the active session too? Currently, it has

    def stop(self):
        """Stop the underlying :class:`SparkContext`.
        """
        self._jvm.SparkSession.clearDefaultSession()
        SparkSession._instantiatedSession = None

Do I need to add the following?

      self._jvm.SparkSession.clearActiveSession()

To test for getActiveSession when there is no active session, I am thinking of adding

    def test_get_active_session_when_no_active_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        spark.stop()
        active = spark.getActiveSession()
        self.assertEqual(active, None)

The test didn't pass because stop() does not clear the active session.
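
For reference, a minimal sketch of what stop() could look like with the active session cleared as well; this mirrors the change this diff ends up making, and clearActiveSession() already exists on the JVM-side SparkSession object:

    def stop(self):
        """Stop the underlying :class:`SparkContext`.
        """
        self._sc.stop()
        # Clear the JVM-side default and active sessions, then drop the
        # Python-side references so getActiveSession() returns None afterwards.
        self._jvm.SparkSession.clearDefaultSession()
        self._jvm.SparkSession.clearActiveSession()
        SparkSession._instantiatedSession = None
        SparkSession._activeSession = None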

Contributor:

Yes, that sounds like the right approach and I think we need that.

Member:

@huaxingao, can you check if the active session is set, for instance, when we call createDataFrame? From a cursory look, we are not setting it.

Contributor Author:

@HyukjinKwon It seems to me that the active session is set correctly in __init__. When createDataFrame is called, we already have a session, and the active session was already set in __init__.

Member:

When createDataFrame, we already have a session

but wouldn't the active session be set incorrectly if session A sets the active session in __init__, then session B sets the active session in __init__, and then session A calls createDataFrame?

Contributor Author:

@HyukjinKwon Do you mean something like this:

    def test_two_spark_session(self):
        session1 = None
        session2 = None
        try:
            session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
            session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
            self.assertEqual(session1, session2)

            df = session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
            activeSession1 = session1.getActiveSession()
            activeSession2 = session2.getActiveSession()
            self.assertEqual(activeSession1, activeSession2)

        finally:
            if session1 is not None:
                session1.stop()
            if session2 is not None:
                session2.stop()

Member:

Similar. I was expecting something like:

    session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
    session2 = SparkSession.builder.config("key2", "value2").getOrCreate()

    assert(session2 == SparkSession.getActiveSession())

    session1.createDataFrame([(1, 'Alice')], ['age', 'name'])

    assert(session1 == SparkSession.getActiveSession())

does this work?

Contributor:

So @HyukjinKwon, in this code session1 and session2 are already equal:

    Welcome to
          ____              __
         / __/__  ___ _____/ /__
        _\ \/ _ \/ _ `/ __/  '_/
       /__ / .__/\_,_/_/ /_/\_\   version 2.3.1
          /_/

    Using Python version 3.6.5 (default, Apr 29 2018 16:14:56)
    SparkSession available as 'spark'.

    >>> session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
    >>> session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
    >>> session1
    <pyspark.sql.session.SparkSession object at 0x7ff6d4843b00>
    >>> session2
    <pyspark.sql.session.SparkSession object at 0x7ff6d4843b00>
    >>> session1 == session2
    True

That being said, having multiple Spark sessions in Python is doable; you just have to call the constructor manually, e.g.:

    >>> session3 = SparkSession(sc)
    >>> session3
    <pyspark.sql.session.SparkSession object at 0x7ff6d3dbd160>

And supporting that is reasonable.
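
For completeness, a minimal sketch of getting two genuinely distinct sessions without calling the constructor directly, using the existing newSession() API (it shares the SparkContext but gets separate session state):

    spark = SparkSession.builder.master("local").getOrCreate()
    other = spark.newSession()  # same SparkContext, fresh session state
    assert spark is not other
    assert spark.sparkContext is other.sparkContext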

Contributor:

If we're going to support this we should have a test for it, or if we aren't going to support it right now we should document the behaviour.

Member:

Oh, okay. I should have been more explicit. I meant:

scala> import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSession

scala> SparkSession.getActiveSession
res0: Option[org.apache.spark.sql.SparkSession] = Some(org.apache.spark.sql.SparkSession@3ef4a8fb)

scala> val session1 = spark
session1: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@3ef4a8fb

scala> val session2 = spark.newSession()
session2: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@4b74a4d

scala> SparkSession.getActiveSession
res1: Option[org.apache.spark.sql.SparkSession] = Some(org.apache.spark.sql.SparkSession@3ef4a8fb)

scala> session2.createDataFrame(Seq(Tuple1(1)))
res2: org.apache.spark.sql.DataFrame = [_1: int]

scala> SparkSession.getActiveSession
res3: Option[org.apache.spark.sql.SparkSession] = Some(org.apache.spark.sql.SparkSession@4b74a4d)

Contributor Author:

@holdenk @HyukjinKwon
Thanks for the comments. I looked at the Scala code; it calls setActiveSession in createDataFrame.

  def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
    SparkSession.setActiveSession(this)
    ...
  }

I will do the same for Python.

    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        SparkSession._activeSession = self
        self._jvm.SparkSession.setActiveSession(self._jsparkSession)

I will also add a test (see test_get_active_session_after_create_dataframe in tests.py below).


    def _repr_html_(self):
        return """
@@ -255,6 +258,29 @@ def newSession(self):
"""
return self.__class__(self._sc, self._jsparkSession.newSession())

    @classmethod
    @since(3.0)
Member:

Let's change this to 2.5

Contributor:

@HyukjinKwon are you OK to mark this comment as resolved since we're now targeting 3.0?

Member:

Yes, at that time 2.5 was targeted. Now 3.0 is targeted, per 9bf397c.

    def getActiveSession(cls):
        """
        Returns the active SparkSession for the current thread, returned by the builder.
        >>> s = SparkSession.getActiveSession()
        >>> l = [('Alice', 1)]
        >>> rdd = s.sparkContext.parallelize(l)
        >>> df = s.createDataFrame(rdd, ['name', 'age'])
        >>> df.select("age").collect()
        [Row(age=1)]
        """
        from pyspark import SparkContext
        sc = SparkContext._active_spark_context
        if sc is None:
            return None
        else:
            if sc._jvm.SparkSession.getActiveSession().isDefined():
                # A JVM-side active session exists: re-wrap it so the
                # Python-side _activeSession bookkeeping is populated,
                # then return the tracked session.
                SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
                return SparkSession._activeSession
            else:
                return None
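
A quick illustration of the branches above, as a sketch (it assumes no SparkContext is running at the start, and mirrors test_get_active_session_when_no_active_session in tests.py below):

    # No SparkContext yet, so the first branch returns None.
    assert SparkSession.getActiveSession() is None

    spark = SparkSession.builder.master("local").getOrCreate()
    assert SparkSession.getActiveSession() == spark

    spark.stop()
    # stop() clears the active session on both sides, so None again.
    assert SparkSession.getActiveSession() is None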

    @property
    @since(2.0)
    def sparkContext(self):
@@ -671,6 +697,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        ...
        Py4JJavaError: ...
        """
        SparkSession._activeSession = self
        self._jvm.SparkSession.setActiveSession(self._jsparkSession)
        if isinstance(data, DataFrame):
            raise TypeError("data is already a DataFrame")

@@ -826,7 +854,9 @@ def stop(self):
        self._sc.stop()
        # We should clean the default session up. See SPARK-23228.
        self._jvm.SparkSession.clearDefaultSession()
        self._jvm.SparkSession.clearActiveSession()
        SparkSession._instantiatedSession = None
        SparkSession._activeSession = None

    @since(2.0)
    def __enter__(self):
151 changes: 151 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3863,6 +3863,157 @@ def test_jvm_default_session_already_set(self):
            spark.stop()


class SparkSessionTests2(unittest.TestCase):

    def test_active_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            activeSession = SparkSession.getActiveSession()
            df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
        finally:
            spark.stop()

    def test_get_active_session_when_no_active_session(self):
        active = SparkSession.getActiveSession()
        self.assertEqual(active, None)
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        active = SparkSession.getActiveSession()
        self.assertEqual(active, spark)
        spark.stop()
        active = SparkSession.getActiveSession()
        self.assertEqual(active, None)
Contributor:

Given the change to how we construct the SparkSession, can we add a test that makes sure we do whatever we decide to do with the SparkContext?

Contributor Author:

Thanks @holdenk
I will add a test for the above comment, and also a test for your comment regarding

    self._jvm.SparkSession.setActiveSession(self._jsparkSession)


    def test_SparkSession(self):
        spark = SparkSession.builder \
            .master("local") \
            .config("some-config", "v2") \
            .getOrCreate()
        try:
            self.assertEqual(spark.conf.get("some-config"), "v2")
            self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2")
            self.assertEqual(spark.version, spark.sparkContext.version)
            spark.sql("CREATE DATABASE test_db")
            spark.catalog.setCurrentDatabase("test_db")
            self.assertEqual(spark.catalog.currentDatabase(), "test_db")
            spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
            self.assertEqual(spark.table("table1").columns, ['name', 'age'])
            self.assertEqual(spark.range(3).count(), 3)
        finally:
            spark.stop()

    def test_global_default_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            self.assertEqual(SparkSession.builder.getOrCreate(), spark)
        finally:
            spark.stop()

    def test_default_and_active_session(self):
        spark = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        activeSession = spark._jvm.SparkSession.getActiveSession()
        defaultSession = spark._jvm.SparkSession.getDefaultSession()
        try:
            self.assertEqual(activeSession, defaultSession)
        finally:
            spark.stop()

    def test_config_option_propagated_to_existing_session(self):
        session1 = SparkSession.builder \
            .master("local") \
            .config("spark-config1", "a") \
            .getOrCreate()
        self.assertEqual(session1.conf.get("spark-config1"), "a")
        session2 = SparkSession.builder \
            .config("spark-config1", "b") \
            .getOrCreate()
        try:
            self.assertEqual(session1, session2)
            self.assertEqual(session1.conf.get("spark-config1"), "b")
        finally:
            session1.stop()

    def test_new_session(self):
        session = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        newSession = session.newSession()
        try:
            self.assertNotEqual(session, newSession)
        finally:
            session.stop()
            newSession.stop()

    def test_create_new_session_if_old_session_stopped(self):
        session = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        session.stop()
        newSession = SparkSession.builder \
            .master("local") \
            .getOrCreate()
        try:
            self.assertNotEqual(session, newSession)
        finally:
            newSession.stop()

    def test_active_session_with_None_and_not_None_context(self):
        from pyspark.context import SparkContext
        from pyspark.conf import SparkConf
        sc = None
        session = None
        try:
            sc = SparkContext._active_spark_context
            self.assertEqual(sc, None)
            activeSession = SparkSession.getActiveSession()
            self.assertEqual(activeSession, None)
            sparkConf = SparkConf()
            sc = SparkContext.getOrCreate(sparkConf)
            activeSession = sc._jvm.SparkSession.getActiveSession()
            self.assertFalse(activeSession.isDefined())
            session = SparkSession(sc)
            activeSession = sc._jvm.SparkSession.getActiveSession()
            self.assertTrue(activeSession.isDefined())
            activeSession2 = SparkSession.getActiveSession()
            self.assertNotEqual(activeSession2, None)
        finally:
            if session is not None:
                session.stop()
            if sc is not None:
                sc.stop()


class SparkSessionTests3(ReusedSQLTestCase):

def test_get_active_session_after_create_dataframe(self):
session2 = None
try:
activeSession1 = SparkSession.getActiveSession()
session1 = self.spark
self.assertEqual(session1, activeSession1)
session2 = self.spark.newSession()
activeSession2 = SparkSession.getActiveSession()
self.assertEqual(session1, activeSession2)
self.assertNotEqual(session2, activeSession2)
session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
activeSession3 = SparkSession.getActiveSession()
self.assertEqual(session2, activeSession3)
session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
activeSession4 = SparkSession.getActiveSession()
self.assertEqual(session1, activeSession4)
finally:
if session2 is not None:
session2.stop()


class UDFInitializationTests(unittest.TestCase):
    def tearDown(self):
        if SparkSession._instantiatedSession is not None: