From 6d6ef769c8fd843f17732f60e5410e8cee75d9ed Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Sat, 13 Nov 2021 15:41:19 -0800
Subject: [PATCH] [SPARK-37291][PYTHON][SQL] PySpark init SparkSession should copy conf to sharedState

### What changes were proposed in this pull request?
When a PySpark script creates the SparkContext before the session, like

```
conf = SparkConf().setAppName("test")
sc = SparkContext(conf=conf)
session = SparkSession.builder.enableHiveSupport().getOrCreate()
```

the resulting session has no Hive support: because an existing SparkContext is reused, the SparkSession is created with

```
SparkSession(sc)
```

which loses the configurations added via `config()` and the builder, such as the catalog implementation.

In the Scala `SparkSession` class, the session is created from a `SparkContext` plus the option configurations; the options are passed to `SharedState`, and `SessionState` is then created from `SharedState`'s conf. In PySpark, however, the options are not passed to `SharedState`: they are applied to `SessionState` directly, after `SessionState` has already been initialized, so Hive support never takes effect.

This PR passes the option configurations to `SharedState` when the `SparkSession` is first initialized; when `SessionState` is initialized later, the options are propagated to it as well.
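For illustration, a minimal end-to-end sketch of the fixed behavior (assuming a Hive-enabled Spark build; the new unit test below performs the same check against `SharedState`'s conf directly):

```
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

# Create the SparkContext first, so the builder must wrap the existing
# context instead of creating one from its own options.
conf = SparkConf().setAppName("test")
sc = SparkContext(conf=conf)

# With this change, the builder's options (including the catalog
# implementation set by enableHiveSupport()) are copied into SharedState
# before SessionState is created.
session = SparkSession.builder.enableHiveSupport().getOrCreate()

# Reports "in-memory" before this change and "hive" after it.
print(session.conf.get("spark.sql.catalogImplementation"))
```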
### Why are the changes needed?
Avoid losing configurations when building a SparkSession in PySpark.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Manually tested and added a unit test.

Closes #34559 from AngersZhuuuu/SPARK-37291.

Authored-by: Angerszhuuuu
Signed-off-by: Dongjoon Hyun
---
 python/pyspark/sql/session.py            | 12 ++++++++++--
 python/pyspark/sql/tests/test_session.py | 23 +++++++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 140262bbb705a..927554198743e 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -276,7 +276,7 @@ def getOrCreate(self) -> "SparkSession":
                     sc = SparkContext.getOrCreate(sparkConf)
                     # Do not update `SparkConf` for existing `SparkContext`, as it's shared
                     # by all sessions.
-                    session = SparkSession(sc)
+                    session = SparkSession(sc, options=self._options)
                 for key, value in self._options.items():
                     session._jsparkSession.sessionState().conf().setConfString(key, value)
                 return session
@@ -287,7 +287,12 @@ def getOrCreate(self) -> "SparkSession":
     _instantiatedSession: ClassVar[Optional["SparkSession"]] = None
     _activeSession: ClassVar[Optional["SparkSession"]] = None
 
-    def __init__(self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = None):
+    def __init__(
+        self,
+        sparkContext: SparkContext,
+        jsparkSession: Optional[JavaObject] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ):
         from pyspark.sql.context import SQLContext
 
         self._sc = sparkContext
@@ -301,6 +306,9 @@ def __init__(self, sparkContext: SparkContext, jsparkSession: Optional[JavaObjec
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
                 jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+            if options is not None:
+                for key, value in options.items():
+                    jsparkSession.sharedState().conf().set(key, value)
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index abc131180511d..eb23b68ccf498 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -289,6 +289,29 @@ def test_another_spark_session(self):
         if session2 is not None:
             session2.stop()
 
+    def test_create_spark_context_first_and_copy_options_to_sharedState(self):
+        sc = None
+        session = None
+        try:
+            conf = SparkConf().set("key1", "value1")
+            sc = SparkContext("local[4]", "SessionBuilderTests", conf=conf)
+            session = (
+                SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate()
+            )
+
+            self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1")
+            self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2")
+            self.assertEqual(
+                session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"),
+                "hive",
+            )
+            self.assertEqual(session.sparkContext, sc)
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
 
 class SparkExtensionsTest(unittest.TestCase):
     # These tests are separate because it uses 'spark.sql.extensions' which is