Commit 506cb1d

Gaetan RACIC committed:
Backport PR [SPARK-25003][PYSPARK] to use SessionExtensions in PySpark (apache#21990)

1 parent f9cdef9 · commit 506cb1d

File tree: 2 files changed, +80 −18 lines
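Context for the change: `spark.sql.extensions` names a configurator class that Spark loads and applies to the session's `SparkSessionExtensions`. Before this backport only sessions created through the JVM-side `getOrCreate` honored the setting; sessions constructed from PySpark silently ignored it. A minimal sketch of enabling it, assuming a hypothetical configurator `com.example.DemoExtensions` (itself sketched after the test diff below):

    import org.apache.spark.sql.SparkSession

    // "com.example.DemoExtensions" is a placeholder class name; it must
    // implement SparkSessionExtensions => Unit and be on the driver classpath.
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("extensions-demo")
      .config("spark.sql.extensions", "com.example.DemoExtensions")
      .getOrCreate()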

python/pyspark/sql/tests.py

Lines changed: 42 additions & 0 deletions

@@ -3711,6 +3711,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
             "The callback from the query execution listener should be called after 'toPandas'")


+class SparkExtensionsTest(unittest.TestCase):
+    # These tests are separate because they use 'spark.sql.extensions', which is
+    # static and immutable. It can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "SparkSessionExtensionSuite.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.extensions' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.extensions",
+                "org.apache.spark.sql.MyExtensions") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def test_use_custom_class_for_extensions(self):
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().planner().strategies().contains(
+                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+            "MySparkStrategy not found in active planner strategies")
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+            "MyRule not found in extended resolution rules")
+
+
 class SparkSessionTests(PySparkTestCase):

     # This test is separate because it's closely related with session's start and stop.
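The `MyExtensions`, `MySparkStrategy`, and `MyRule` classes referenced above come from the Scala-side `SparkSessionExtensionSuite`. Roughly, such a configurator is a `SparkSessionExtensions => Unit` that injects a planner strategy and a resolution rule. A self-contained sketch with hypothetical `Demo*` names (not the suite's actual code):

    import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical no-op resolution rule: returns the plan unchanged.
    case class DemoRule(spark: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    // Hypothetical planner strategy that never matches, deferring to the others.
    case class DemoStrategy(spark: SparkSession) extends Strategy {
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
    }

    // The configurator class named by 'spark.sql.extensions'.
    class DemoExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectPlannerStrategy(DemoStrategy)
        extensions.injectResolutionRule(DemoRule)
      }
    }

The test then asserts, through the Py4J gateway (`_jsparkSession`, `_jvm`), that instances equal to these appear in the session's planner strategies and extended resolution rules; the equality check works because they are case classes parameterized only by the session.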

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 38 additions & 18 deletions

@@ -84,8 +84,17 @@ class SparkSession private(
   // The call site where this SparkSession was constructed.
   private val creationSite: CallSite = Utils.getCallSite()

+  /**
+   * Constructor used in PySpark. Contains an explicit application of Spark Session Extensions,
+   * which otherwise only occurs during getOrCreate. We cannot add this to the default
+   * constructor, since that would cause every new session to reinvoke Spark Session Extensions
+   * on the currently running extensions.
+   */
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None, new SparkSessionExtensions)
+    this(sc, None, None,
+      SparkSession.applyExtensions(
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        new SparkSessionExtensions))
   }

   sparkContext.assertNotStopped()

@@ -936,23 +945,9 @@ object SparkSession extends Logging {
       // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
     }

-    // Initialize extensions if the user has defined a configurator class.
-    val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-    if (extensionConfOption.isDefined) {
-      val extensionConfClassName = extensionConfOption.get
-      try {
-        val extensionConfClass = Utils.classForName(extensionConfClassName)
-        val extensionConf = extensionConfClass.newInstance()
-          .asInstanceOf[SparkSessionExtensions => Unit]
-        extensionConf(extensions)
-      } catch {
-        // Ignore the error if we cannot find the class or when the class has the wrong type.
-        case e @ (_: ClassCastException |
-            _: ClassNotFoundException |
-            _: NoClassDefFoundError) =>
-          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-      }
-    }
+    applyExtensions(
+      sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+      extensions)

     session = new SparkSession(sparkContext, None, None, extensions)
     options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }

@@ -1137,4 +1132,29 @@ object SparkSession extends Logging {
       SparkSession.clearDefaultSession()
     }
   }
+
+  /**
+   * Initialize extensions for the given extension class name. The class will be applied to the
+   * extensions passed into this function.
+   */
+  private def applyExtensions(
+      extensionOption: Option[String],
+      extensions: SparkSessionExtensions): SparkSessionExtensions = {
+    if (extensionOption.isDefined) {
+      val extensionConfClassName = extensionOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e @ (_: ClassCastException |
+            _: ClassNotFoundException |
+            _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+    extensions
+  }
 }
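With this change both session entry points funnel through `applyExtensions`: the JVM-side `getOrCreate`, and the `private[sql]` constructor that the PySpark wrapper reaches over Py4J (roughly `self._jvm.SparkSession(self._jsc.sc())` on the Python side). In isolation, the reflective step it performs looks like the following sketch, with `com.example.DemoExtensions` again standing in for a real configurator:

    import org.apache.spark.sql.SparkSessionExtensions

    // Load the configured class, instantiate it, and apply it as a
    // SparkSessionExtensions => Unit. The real applyExtensions additionally
    // catches ClassCastException, ClassNotFoundException, and
    // NoClassDefFoundError, logging a warning instead of failing session
    // construction.
    val extensions = new SparkSessionExtensions
    val configurator = Class.forName("com.example.DemoExtensions")
      .newInstance()
      .asInstanceOf[SparkSessionExtensions => Unit]
    configurator(extensions)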
