[SPARK-37291][PYTHON][SQL] PySpark init SparkSession should copy conf to sharedState

### What changes were proposed in this pull request?
When writing a PySpark script like
```
conf = SparkConf().setAppName("test")
sc = SparkContext(conf=conf)
session = SparkSession.builder.enableHiveSupport().getOrCreate()
```

it builds a session without Hive support, because we reuse the existing SparkContext and create the SparkSession with
```
SparkSession(sc)
```
As a result, we lose the configuration added via `config()`, such as the catalog implementation.

In the Scala `SparkSession` class, we create the `SparkSession` from a `SparkContext` plus optional configurations; the options are passed to `SharedState`, and `SharedState`'s conf is then used to create the `SessionState`. In PySpark, however, we do not pass the options to `SharedState` but only to `SessionState`, and by that point the `SessionState` has already been initialized. So the session does not support Hive.
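
For illustration, a minimal sketch of the symptom (assuming a Hive-enabled Spark build; `spark.sql.catalogImplementation` is the setting `SharedState` consults when choosing the catalog):

```
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

conf = SparkConf().setAppName("test")
sc = SparkContext(conf=conf)  # the SparkContext is created first

# The builder reuses the existing SparkContext. Before this fix, the
# enableHiveSupport() option never reached SharedState, so the session
# fell back to the default in-memory catalog.
session = SparkSession.builder.enableHiveSupport().getOrCreate()

print(session.conf.get("spark.sql.catalogImplementation"))
# "in-memory" before this fix (the default), "hive" after it
```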

In this PR, the option configurations are passed to `SharedState` when the `SparkSession` is first initialized; when the `SessionState` is initialized later, the options are passed on to it as well.

### Why are the changes needed?
Avoid losing configuration when building a SparkSession in PySpark.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Manually tested and added a unit test.

Closes #34559 from AngersZhuuuu/SPARK-37291.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
AngersZhuuuu authored and dongjoon-hyun committed Nov 13, 2021
1 parent 6426dcf commit 6d6ef76
Showing 2 changed files with 33 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/pyspark/sql/session.py
```
@@ -276,7 +276,7 @@ def getOrCreate(self) -> "SparkSession":
                     sc = SparkContext.getOrCreate(sparkConf)
                     # Do not update `SparkConf` for existing `SparkContext`, as it's shared
                     # by all sessions.
-                    session = SparkSession(sc)
+                    session = SparkSession(sc, options=self._options)
                 for key, value in self._options.items():
                     session._jsparkSession.sessionState().conf().setConfString(key, value)
                 return session
```
```
@@ -287,7 +287,12 @@ def getOrCreate(self) -> "SparkSession":
     _instantiatedSession: ClassVar[Optional["SparkSession"]] = None
     _activeSession: ClassVar[Optional["SparkSession"]] = None
 
-    def __init__(self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = None):
+    def __init__(
+        self,
+        sparkContext: SparkContext,
+        jsparkSession: Optional[JavaObject] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ):
         from pyspark.sql.context import SQLContext
 
         self._sc = sparkContext
```
```
@@ -301,6 +306,9 @@ def __init__(
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
                 jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+        if options is not None:
+            for key, value in options.items():
+                jsparkSession.sharedState().conf().set(key, value)
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
```
23 changes: 23 additions & 0 deletions python/pyspark/sql/tests/test_session.py
```
@@ -289,6 +289,29 @@ def test_another_spark_session(self):
             if session2 is not None:
                 session2.stop()
 
+    def test_create_spark_context_first_and_copy_options_to_sharedState(self):
+        sc = None
+        session = None
+        try:
+            conf = SparkConf().set("key1", "value1")
+            sc = SparkContext("local[4]", "SessionBuilderTests", conf=conf)
+            session = (
+                SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate()
+            )
+
+            self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1")
+            self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2")
+            self.assertEqual(
+                session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"),
+                "hive",
+            )
+            self.assertEqual(session.sparkContext, sc)
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
 
 class SparkExtensionsTest(unittest.TestCase):
     # These tests are separate because it uses 'spark.sql.extensions' which is
```
