-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-25255][PYTHON] Add getActiveSession to SparkSession in PySpark #22295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5bfe614
9048a36
221ea01
6f89066
c223dd2
091b1d5
1cda049
69b29e9
d8fef1c
59ad7a7
f2949f1
7c6d2d5
fb47432
94e3db0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -192,6 +192,7 @@ def getOrCreate(self): | |
| """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances""" | ||
|
|
||
| _instantiatedSession = None | ||
| _activeSession = None | ||
|
|
||
| @ignore_unicode_prefix | ||
| def __init__(self, sparkContext, jsparkSession=None): | ||
|
|
@@ -233,7 +234,9 @@ def __init__(self, sparkContext, jsparkSession=None): | |
| if SparkSession._instantiatedSession is None \ | ||
| or SparkSession._instantiatedSession._sc._jsc is None: | ||
| SparkSession._instantiatedSession = self | ||
| SparkSession._activeSession = self | ||
| self._jvm.SparkSession.setDefaultSession(self._jsparkSession) | ||
| self._jvm.SparkSession.setActiveSession(self._jsparkSession) | ||
|
||
|
|
||
| def _repr_html_(self): | ||
| return """ | ||
|
|
@@ -255,6 +258,29 @@ def newSession(self): | |
| """ | ||
| return self.__class__(self._sc, self._jsparkSession.newSession()) | ||
|
|
||
| @classmethod | ||
| @since(3.0) | ||
|
||
| def getActiveSession(cls): | ||
| """ | ||
| Returns the active SparkSession for the current thread, returned by the builder. | ||
| >>> s = SparkSession.getActiveSession() | ||
| >>> l = [('Alice', 1)] | ||
| >>> rdd = s.sparkContext.parallelize(l) | ||
| >>> df = s.createDataFrame(rdd, ['name', 'age']) | ||
| >>> df.select("age").collect() | ||
| [Row(age=1)] | ||
| """ | ||
| from pyspark import SparkContext | ||
| sc = SparkContext._active_spark_context | ||
| if sc is None: | ||
| return None | ||
| else: | ||
| if sc._jvm.SparkSession.getActiveSession().isDefined(): | ||
| SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) | ||
| return SparkSession._activeSession | ||
| else: | ||
| return None | ||
|
|
||
| @property | ||
| @since(2.0) | ||
| def sparkContext(self): | ||
|
|
@@ -671,6 +697,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr | |
| ... | ||
| Py4JJavaError: ... | ||
| """ | ||
| SparkSession._activeSession = self | ||
| self._jvm.SparkSession.setActiveSession(self._jsparkSession) | ||
| if isinstance(data, DataFrame): | ||
| raise TypeError("data is already a DataFrame") | ||
|
|
||
|
|
@@ -826,7 +854,9 @@ def stop(self): | |
| self._sc.stop() | ||
| # We should clean the default session up. See SPARK-23228. | ||
| self._jvm.SparkSession.clearDefaultSession() | ||
| self._jvm.SparkSession.clearActiveSession() | ||
| SparkSession._instantiatedSession = None | ||
| SparkSession._activeSession = None | ||
|
|
||
| @since(2.0) | ||
| def __enter__(self): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3863,6 +3863,157 @@ def test_jvm_default_session_already_set(self): | |
| spark.stop() | ||
|
|
||
|
|
||
| class SparkSessionTests2(unittest.TestCase): | ||
|
|
||
| def test_active_session(self): | ||
| spark = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| try: | ||
| activeSession = SparkSession.getActiveSession() | ||
| df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) | ||
| self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) | ||
| finally: | ||
| spark.stop() | ||
|
|
||
| def test_get_active_session_when_no_active_session(self): | ||
| active = SparkSession.getActiveSession() | ||
| self.assertEqual(active, None) | ||
| spark = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| active = SparkSession.getActiveSession() | ||
| self.assertEqual(active, spark) | ||
| spark.stop() | ||
| active = SparkSession.getActiveSession() | ||
| self.assertEqual(active, None) | ||
|
||
|
|
||
| def test_SparkSession(self): | ||
| spark = SparkSession.builder \ | ||
| .master("local") \ | ||
| .config("some-config", "v2") \ | ||
| .getOrCreate() | ||
| try: | ||
| self.assertEqual(spark.conf.get("some-config"), "v2") | ||
| self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") | ||
| self.assertEqual(spark.version, spark.sparkContext.version) | ||
| spark.sql("CREATE DATABASE test_db") | ||
| spark.catalog.setCurrentDatabase("test_db") | ||
| self.assertEqual(spark.catalog.currentDatabase(), "test_db") | ||
| spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") | ||
| self.assertEqual(spark.table("table1").columns, ['name', 'age']) | ||
| self.assertEqual(spark.range(3).count(), 3) | ||
| finally: | ||
| spark.stop() | ||
|
|
||
| def test_global_default_session(self): | ||
| spark = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| try: | ||
| self.assertEqual(SparkSession.builder.getOrCreate(), spark) | ||
| finally: | ||
| spark.stop() | ||
|
|
||
| def test_default_and_active_session(self): | ||
| spark = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| activeSession = spark._jvm.SparkSession.getActiveSession() | ||
| defaultSession = spark._jvm.SparkSession.getDefaultSession() | ||
| try: | ||
| self.assertEqual(activeSession, defaultSession) | ||
| finally: | ||
| spark.stop() | ||
|
|
||
| def test_config_option_propagated_to_existing_session(self): | ||
| session1 = SparkSession.builder \ | ||
| .master("local") \ | ||
| .config("spark-config1", "a") \ | ||
| .getOrCreate() | ||
| self.assertEqual(session1.conf.get("spark-config1"), "a") | ||
| session2 = SparkSession.builder \ | ||
| .config("spark-config1", "b") \ | ||
| .getOrCreate() | ||
| try: | ||
| self.assertEqual(session1, session2) | ||
| self.assertEqual(session1.conf.get("spark-config1"), "b") | ||
| finally: | ||
| session1.stop() | ||
|
|
||
| def test_new_session(self): | ||
| session = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| newSession = session.newSession() | ||
| try: | ||
| self.assertNotEqual(session, newSession) | ||
| finally: | ||
| session.stop() | ||
| newSession.stop() | ||
|
|
||
| def test_create_new_session_if_old_session_stopped(self): | ||
| session = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| session.stop() | ||
| newSession = SparkSession.builder \ | ||
| .master("local") \ | ||
| .getOrCreate() | ||
| try: | ||
| self.assertNotEqual(session, newSession) | ||
| finally: | ||
| newSession.stop() | ||
|
|
||
| def test_active_session_with_None_and_not_None_context(self): | ||
| from pyspark.context import SparkContext | ||
| from pyspark.conf import SparkConf | ||
| sc = None | ||
| session = None | ||
| try: | ||
| sc = SparkContext._active_spark_context | ||
| self.assertEqual(sc, None) | ||
| activeSession = SparkSession.getActiveSession() | ||
| self.assertEqual(activeSession, None) | ||
| sparkConf = SparkConf() | ||
| sc = SparkContext.getOrCreate(sparkConf) | ||
| activeSession = sc._jvm.SparkSession.getActiveSession() | ||
| self.assertFalse(activeSession.isDefined()) | ||
| session = SparkSession(sc) | ||
| activeSession = sc._jvm.SparkSession.getActiveSession() | ||
| self.assertTrue(activeSession.isDefined()) | ||
| activeSession2 = SparkSession.getActiveSession() | ||
| self.assertNotEqual(activeSession2, None) | ||
| finally: | ||
| if session is not None: | ||
| session.stop() | ||
| if sc is not None: | ||
| sc.stop() | ||
|
|
||
|
|
||
| class SparkSessionTests3(ReusedSQLTestCase): | ||
|
|
||
| def test_get_active_session_after_create_dataframe(self): | ||
| session2 = None | ||
| try: | ||
| activeSession1 = SparkSession.getActiveSession() | ||
| session1 = self.spark | ||
| self.assertEqual(session1, activeSession1) | ||
| session2 = self.spark.newSession() | ||
| activeSession2 = SparkSession.getActiveSession() | ||
| self.assertEqual(session1, activeSession2) | ||
| self.assertNotEqual(session2, activeSession2) | ||
| session2.createDataFrame([(1, 'Alice')], ['age', 'name']) | ||
| activeSession3 = SparkSession.getActiveSession() | ||
| self.assertEqual(session2, activeSession3) | ||
| session1.createDataFrame([(1, 'Alice')], ['age', 'name']) | ||
| activeSession4 = SparkSession.getActiveSession() | ||
| self.assertEqual(session1, activeSession4) | ||
| finally: | ||
| if session2 is not None: | ||
| session2.stop() | ||
|
|
||
|
|
||
| class UDFInitializationTests(unittest.TestCase): | ||
| def tearDown(self): | ||
| if SparkSession._instantiatedSession is not None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching this! Filed a follow up https://issues.apache.org/jira/browse/SPARK-25432
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very much for your comments.
I have a question here. In stop() method, shall we clear the activeSession too? Currently, it has
Do I need to add the following?
To test for getActiveSession when there is no active session, I am thinking of adding
The test didn't pass because in stop(), the active session is not cleared.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that sounds like the right approach and I think we need that.