@@ -84,8 +84,17 @@ class SparkSession private(
8484 // The call site where this SparkSession was constructed.
8585 private val creationSite : CallSite = Utils .getCallSite()
8686
87+ /**
88+ * Constructor used in Pyspark. Contains explicit application of Spark Session Extensions
89+ * which otherwise only occurs during getOrCreate. We cannot add this to the default constructor
90+ * since that would cause every new session to reinvoke Spark Session Extensions on the currently
91+ * running extensions.
92+ */
8793 private [sql] def this (sc : SparkContext ) {
88- this (sc, None , None , new SparkSessionExtensions )
94+ this (sc, None , None ,
95+ SparkSession .applyExtensions(
96+ sc.getConf.get(StaticSQLConf .SPARK_SESSION_EXTENSIONS ),
97+ new SparkSessionExtensions ))
8998 }
9099
91100 sparkContext.assertNotStopped()
@@ -936,23 +945,9 @@ object SparkSession extends Logging {
936945 // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
937946 }
938947
939- // Initialize extensions if the user has defined a configurator class.
940- val extensionConfOption = sparkContext.conf.get(StaticSQLConf .SPARK_SESSION_EXTENSIONS )
941- if (extensionConfOption.isDefined) {
942- val extensionConfClassName = extensionConfOption.get
943- try {
944- val extensionConfClass = Utils .classForName(extensionConfClassName)
945- val extensionConf = extensionConfClass.newInstance()
946- .asInstanceOf [SparkSessionExtensions => Unit ]
947- extensionConf(extensions)
948- } catch {
949- // Ignore the error if we cannot find the class or when the class has the wrong type.
950- case e @ (_ : ClassCastException |
951- _ : ClassNotFoundException |
952- _ : NoClassDefFoundError ) =>
953- logWarning(s " Cannot use $extensionConfClassName to configure session extensions. " , e)
954- }
955- }
948+ applyExtensions(
949+ sparkContext.getConf.get(StaticSQLConf .SPARK_SESSION_EXTENSIONS ),
950+ extensions)
956951
957952 session = new SparkSession (sparkContext, None , None , extensions)
958953 options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1137,4 +1132,29 @@ object SparkSession extends Logging {
11371132 SparkSession .clearDefaultSession()
11381133 }
11391134 }
1135+
1136+ /**
1137+ * Initialize extensions for given extension classname. This class will be applied to the
1138+ * extensions passed into this function.
1139+ */
1140+ private def applyExtensions (
1141+ extensionOption : Option [String ],
1142+ extensions : SparkSessionExtensions ): SparkSessionExtensions = {
1143+ if (extensionOption.isDefined) {
1144+ val extensionConfClassName = extensionOption.get
1145+ try {
1146+ val extensionConfClass = Utils .classForName(extensionConfClassName)
1147+ val extensionConf = extensionConfClass.newInstance()
1148+ .asInstanceOf [SparkSessionExtensions => Unit ]
1149+ extensionConf(extensions)
1150+ } catch {
1151+ // Ignore the error if we cannot find the class or when the class has the wrong type.
1152+ case e@ (_ : ClassCastException |
1153+ _ : ClassNotFoundException |
1154+ _ : NoClassDefFoundError ) =>
1155+ logWarning(s " Cannot use $extensionConfClassName to configure session extensions. " , e)
1156+ }
1157+ }
1158+ extensions
1159+ }
11401160}
0 commit comments