
Commit 67d9772

SPARK-25003: Add helper methods to create new Extensions from Conf
Previously the only way to add extensions to the session was via the getOrCreate method of the SparkSession Builder. To facilitate non-Scala session creation, this adds a new SparkSession constructor which takes just the SparkContext and a SparkSessionExtensions instance. It also adds a new SparkSessionExtensions constructor which, given a SparkConf, produces an extensions object with the user's configured extensions already applied.
1 parent def4f3e commit 67d9772
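
For context on what the conf-driven path expects: the class named by spark.sql.extensions must be of type SparkSessionExtensions => Unit, which the helper below instantiates and applies. A minimal sketch of such a configurator class (the class name and the no-op parser injection are hypothetical, purely for illustration, not part of this commit):

import org.apache.spark.sql.SparkSessionExtensions

// Hypothetical configurator class, referenced by name through spark.sql.extensions.
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Placeholder customization: a parser builder that returns the delegate parser unchanged.
    extensions.injectParser((_, parser) => parser)
  }
}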

File tree

  python/pyspark/sql/session.py
  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
  sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala

3 files changed: +45 -20 lines

python/pyspark/sql/session.py
Lines changed: 5 additions & 3 deletions

@@ -212,15 +212,17 @@ def __init__(self, sparkContext, jsparkSession=None):
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
+
         if jsparkSession is None:
             if self._jvm.SparkSession.getDefaultSession().isDefined() \
                     and not self._jvm.SparkSession.getDefaultSession().get() \
                         .sparkContext().isStopped():
                 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
             else:
-                jsparkSession = self._jvm.SparkSession.builder() \
-                    .sparkContext(self._jsc.sc()) \
-                    .getOrCreate()
+                extensions = self._sc._jvm.org.apache.spark.sql\
+                    .SparkSessionExtensions(self._jsc.getConf())
+                jsparkSession = self._jvm.SparkSession(self._jsc.sc(), extensions)
+
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
Lines changed: 5 additions & 17 deletions

@@ -88,6 +88,10 @@ class SparkSession private(
     this(sc, None, None, new SparkSessionExtensions)
   }
 
+  private[sql] def this(sc: SparkContext, extensions: SparkSessionExtensions) {
+    this(sc, None, None, extensions)
+  }
+
   sparkContext.assertNotStopped()
 
   // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's.
@@ -935,23 +939,7 @@ object SparkSession extends Logging {
        // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
      }
 
-      // Initialize extensions if the user has defined a configurator class.
-      val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-      if (extensionConfOption.isDefined) {
-        val extensionConfClassName = extensionConfOption.get
-        try {
-          val extensionConfClass = Utils.classForName(extensionConfClassName)
-          val extensionConf = extensionConfClass.newInstance()
-            .asInstanceOf[SparkSessionExtensions => Unit]
-          extensionConf(extensions)
-        } catch {
-          // Ignore the error if we cannot find the class or when the class has the wrong type.
-          case e @ (_: ClassCastException |
-                    _: ClassNotFoundException |
-                    _: NoClassDefFoundError) =>
-            logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-        }
-      }
+      SparkSessionExtensions.applyExtensionsFromConf(sparkContext.conf, extensions)
 
       session = new SparkSession(sparkContext, None, None, extensions)
       options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }

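Taken together with the conf-aware SparkSessionExtensions constructor below, the new private[sql] constructor gives non-Scala frontends a direct way to build a JVM session with extensions already applied, without going through Builder.getOrCreate. Roughly, the Py4J calls in session.py above correspond to this Scala call path (a sketch only; the object name is made up, and both constructors are private[sql], so this compiles only inside the org.apache.spark.sql package):

package org.apache.spark.sql

import org.apache.spark.SparkContext

// Hypothetical illustration of the internal call path, not part of the commit.
private[sql] object SessionFromContextSketch {
  def create(sc: SparkContext): SparkSession = {
    // Build extensions from the context's conf, applying any configured spark.sql.extensions class.
    val extensions = new SparkSessionExtensions(sc.getConf)
    // Construct the session directly with those extensions.
    new SparkSession(sc, extensions)
  }
}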
sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
Lines changed: 35 additions & 0 deletions

@@ -19,10 +19,14 @@ package org.apache.spark.sql
 
 import scala.collection.mutable
 
+import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.StaticSQLConf
+import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
@@ -66,6 +70,11 @@ class SparkSessionExtensions {
   type StrategyBuilder = SparkSession => Strategy
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
 
+  private[sql] def this(conf: SparkConf) {
+    this()
+    SparkSessionExtensions.applyExtensionsFromConf(conf, this)
+  }
+
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
   /**
@@ -169,3 +178,29 @@ class SparkSessionExtensions {
     parserBuilders += builder
   }
 }
+
+object SparkSessionExtensions extends Logging {
+
+  /**
+   * Initialize extensions if the user has defined a configurator class in their SparkConf.
+   * This class will be applied to the extensions passed into this function.
+   */
+  private[sql] def applyExtensionsFromConf(conf: SparkConf, extensions: SparkSessionExtensions) {
+    val extensionConfOption = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+    if (extensionConfOption.isDefined) {
+      val extensionConfClassName = extensionConfOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e@(_: ClassCastException |
+                _: ClassNotFoundException |
+                _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+  }
+}

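One behavioral note preserved by the refactor: applyExtensionsFromConf only logs a warning when the configured class is missing or has the wrong type, so a bad spark.sql.extensions value still yields a working (unextended) session. A small sketch exercising the static helper with a class name that does not resolve (hypothetical names; private[sql], so again only usable from inside org.apache.spark.sql):

package org.apache.spark.sql

import org.apache.spark.SparkConf

// Hypothetical illustration only.
private[sql] object ExtensionsFromConfSketch {
  def build(): SparkSessionExtensions = {
    // StaticSQLConf.SPARK_SESSION_EXTENSIONS reads the "spark.sql.extensions" key.
    val conf = new SparkConf().set("spark.sql.extensions", "com.example.DoesNotExist")
    val extensions = new SparkSessionExtensions
    // Logs "Cannot use com.example.DoesNotExist to configure session extensions." and moves on.
    SparkSessionExtensions.applyExtensionsFromConf(conf, extensions)
    extensions
  }
}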