
Commit f0de600

sameeragarwal authored and gatorsmile committed
[SPARK-18127] Add hooks and extension points to Spark
## What changes were proposed in this pull request?

This patch adds support for customizing the spark session by injecting user-defined custom extensions. This allows a user to add custom analyzer rules/checks, optimizer rules, planning strategies or even a customized parser.

## How was this patch tested?

Unit Tests in SparkSessionExtensionSuite

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes #17724 from sameeragarwal/session-extensions.

(cherry picked from commit caf3920)
Signed-off-by: Xiao Li <gatorsmile@gmail.com>
1 parent f971ce5 commit f0de600

7 files changed: +418 -25 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala

Lines changed: 5 additions & 4 deletions
@@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 abstract class AbstractSqlParser extends ParserInterface with Logging {
 
   /** Creates/Resolves DataType for a given SQL string. */
-  def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
-    // TODO add this to the parser interface.
+  override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
     astBuilder.visitSingleDataType(parser.singleDataType())
   }
 
@@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
   }
 
   /** Creates FunctionIdentifier for a given SQL string. */
-  def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser =>
-    astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+    parse(sqlText) { parser =>
+      astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+    }
   }
 
   /**
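AbstractSqlParser already implemented parseDataType; the change above simply marks it (and parseFunctionIdentifier) as an override now that the method is being promoted into ParserInterface in the next file. As a result, any parser can be asked for a DataType directly. A minimal sketch, using the built-in Catalyst parser object (which extends AbstractSqlParser); the DDL string is just an illustration:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.DataType

// CatalystSqlParser is an AbstractSqlParser, so parseDataType is now available
// through the ParserInterface contract as well.
val dt: DataType = CatalystSqlParser.parseDataType("array<struct<a:int,b:string>>")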

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala

Lines changed: 28 additions & 7 deletions
@@ -17,30 +17,51 @@
 
 package org.apache.spark.sql.catalyst.parser
 
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * Interface for a parser.
  */
+@DeveloperApi
 trait ParserInterface {
-  /** Creates LogicalPlan for a given SQL string. */
+  /**
+   * Parse a string to a [[LogicalPlan]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a LogicalPlan")
   def parsePlan(sqlText: String): LogicalPlan
 
-  /** Creates Expression for a given SQL string. */
+  /**
+   * Parse a string to an [[Expression]].
+   */
+  @throws[ParseException]("Text cannot be parsed to an Expression")
   def parseExpression(sqlText: String): Expression
 
-  /** Creates TableIdentifier for a given SQL string. */
+  /**
+   * Parse a string to a [[TableIdentifier]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a TableIdentifier")
   def parseTableIdentifier(sqlText: String): TableIdentifier
 
-  /** Creates FunctionIdentifier for a given SQL string. */
+  /**
+   * Parse a string to a [[FunctionIdentifier]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
   def parseFunctionIdentifier(sqlText: String): FunctionIdentifier
 
   /**
-   * Creates StructType for a given SQL string, which is a comma separated list of field
-   * definitions which will preserve the correct Hive metadata.
+   * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
+   * of field definitions which will preserve the correct Hive metadata.
    */
+  @throws[ParseException]("Text cannot be parsed to a schema")
   def parseTableSchema(sqlText: String): StructType
+
+  /**
+   * Parse a string to a [[DataType]].
+   */
+  @throws[ParseException]("Text cannot be parsed to a DataType")
+  def parseDataType(sqlText: String): DataType
 }
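With parseDataType added, ParserInterface now covers every entry point a session's parser needs, so a third-party parser only has to implement this one trait. A common pattern, and the one the injectParser extension point further down is designed around, is to wrap the parser Spark already provides and delegate to it. A minimal sketch with hypothetical names (MyDelegatingParser is an illustration, not part of this patch):

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}

// Hypothetical parser that intercepts parsePlan and forwards everything else
// to the parser it wraps (normally Spark's own SQL parser).
class MyDelegatingParser(delegate: ParserInterface) extends ParserInterface {

  override def parsePlan(sqlText: String): LogicalPlan = {
    // Custom pre-processing of sqlText could happen here before delegating.
    delegate.parsePlan(sqlText)
  }

  override def parseExpression(sqlText: String): Expression =
    delegate.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType =
    delegate.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType =
    delegate.parseDataType(sqlText)
}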

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala

Lines changed: 6 additions & 0 deletions
@@ -81,4 +81,10 @@ object StaticSQLConf {
       "SQL configuration and the current database.")
     .booleanConf
     .createWithDefault(false)
+
+  val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
+    .doc("Name of the class used to configure Spark Session extensions. The class should " +
+      "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
+    .stringConf
+    .createOptional
 }
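The configurator class named by spark.sql.extensions is instantiated reflectively in SparkSession.Builder.getOrCreate() (see the SparkSession.scala changes below) and applied to the session's extensions. A minimal sketch of such a class, with hypothetical names (MyRule and MyExtensionsConfigurator are illustrations, not part of this patch):

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical analyzer rule; it is built once per session and here leaves plans unchanged.
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// A Function1[SparkSessionExtensions, Unit] with a no-args constructor; its fully
// qualified class name is the value spark.sql.extensions should point at.
class MyExtensionsConfigurator extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectResolutionRule(MyRule)
  }
}

It could then be activated with something like --conf spark.sql.extensions=com.example.MyExtensionsConfigurator (the package name here is hypothetical) when submitting the application.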

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 39 additions & 6 deletions
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
+import org.apache.spark.sql.internal._
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
@@ -77,11 +77,12 @@ import org.apache.spark.util.Utils
 class SparkSession private(
     @transient val sparkContext: SparkContext,
     @transient private val existingSharedState: Option[SharedState],
-    @transient private val parentSessionState: Option[SessionState])
+    @transient private val parentSessionState: Option[SessionState],
+    @transient private[sql] val extensions: SparkSessionExtensions)
   extends Serializable with Closeable with Logging { self =>
 
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None)
+    this(sc, None, None, new SparkSessionExtensions)
   }
 
   sparkContext.assertNotStopped()
@@ -219,7 +220,7 @@
    * @since 2.0.0
    */
   def newSession(): SparkSession = {
-    new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+    new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
   }
 
   /**
@@ -235,7 +236,7 @@
    * implementation is Hive, this will initialize the metastore, which may take some time.
    */
   private[sql] def cloneSession(): SparkSession = {
-    val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+    val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
     result.sessionState // force copy of SessionState
     result
   }
@@ -754,6 +755,8 @@
 
     private[this] val options = new scala.collection.mutable.HashMap[String, String]
 
+    private[this] val extensions = new SparkSessionExtensions
+
     private[this] var userSuppliedContext: Option[SparkContext] = None
 
     private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
@@ -847,6 +850,17 @@
       }
     }
 
+    /**
+     * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules,
+     * Optimizer rules, Planning Strategies or a customized parser.
+     *
+     * @since 2.2.0
+     */
+    def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
+      f(extensions)
+      this
+    }
+
     /**
      * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
      * one based on the options set in this builder.
@@ -903,7 +917,26 @@ object SparkSession {
          }
          sc
        }
-        session = new SparkSession(sparkContext)
+
+        // Initialize extensions if the user has defined a configurator class.
+        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+        if (extensionConfOption.isDefined) {
+          val extensionConfClassName = extensionConfOption.get
+          try {
+            val extensionConfClass = Utils.classForName(extensionConfClassName)
+            val extensionConf = extensionConfClass.newInstance()
+              .asInstanceOf[SparkSessionExtensions => Unit]
+            extensionConf(extensions)
+          } catch {
+            // Ignore the error if we cannot find the class or when the class has the wrong type.
+            case e @ (_: ClassCastException |
+                      _: ClassNotFoundException |
+                      _: NoClassDefFoundError) =>
+              logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+          }
+        }
+
+        session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
         defaultSession.set(session)
 
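For applications that construct their own SparkSession, the same wiring can be done in code through the new Builder.withExtensions method added above, without touching spark.sql.extensions. A short, self-contained sketch (the no-op rule is hypothetical and only stands in for real user logic):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("extensions-demo")
  .withExtensions { extensions =>
    // Register a builder; the rule itself is constructed later, once per session.
    extensions.injectResolutionRule { session =>
      new Rule[LogicalPlan] {
        // Hypothetical no-op body; a real rule would transform the plan here.
        override def apply(plan: LogicalPlan): LogicalPlan = plan
      }
    }
  }
  .getOrCreate()

Because withExtensions only mutates the builder-local SparkSessionExtensions instance, it can be called several times, and getOrCreate() applies the spark.sql.extensions configurator (if any) to the same instance, so the two routes combine.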

sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala

Lines changed: 171 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import scala.collection.mutable

import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * :: Experimental ::
 * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability
 * regarding binary compatibility and source compatibility of methods here.
 *
 * This currently provides the following extension points:
 *  - Analyzer Rules.
 *  - Check Analysis Rules.
 *  - Optimizer Rules.
 *  - Planning Strategies.
 *  - Customized Parser.
 *  - (External) Catalog listeners.
 *
 * The extensions can be used by calling withExtensions on the [[SparkSession.Builder]], for
 * example:
 * {{{
 *   SparkSession.builder()
 *     .master("...")
 *     .conf("...", true)
 *     .withExtensions { extensions =>
 *       extensions.injectResolutionRule { session =>
 *         ...
 *       }
 *       extensions.injectParser { (session, parser) =>
 *         ...
 *       }
 *     }
 *     .getOrCreate()
 * }}}
 *
 * Note that none of the injected builders should assume that the [[SparkSession]] is fully
 * initialized and should not touch the session's internals (e.g. the SessionState).
 */
@DeveloperApi
@Experimental
@InterfaceStability.Unstable
class SparkSessionExtensions {
  type RuleBuilder = SparkSession => Rule[LogicalPlan]
  type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
  type StrategyBuilder = SparkSession => Strategy
  type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface

  private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

  /**
   * Build the analyzer resolution `Rule`s using the given [[SparkSession]].
   */
  private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
    resolutionRuleBuilders.map(_.apply(session))
  }

  /**
   * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer
   * rules will be executed as part of the resolution phase of analysis.
   */
  def injectResolutionRule(builder: RuleBuilder): Unit = {
    resolutionRuleBuilders += builder
  }

  private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

  /**
   * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]].
   */
  private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
    postHocResolutionRuleBuilders.map(_.apply(session))
  }

  /**
   * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer
   * rules will be executed after resolution.
   */
  def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
    postHocResolutionRuleBuilders += builder
  }

  private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder]

  /**
   * Build the check analysis `Rule`s using the given [[SparkSession]].
   */
  private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = {
    checkRuleBuilders.map(_.apply(session))
  }

  /**
   * Inject a check analysis `Rule` builder into the [[SparkSession]]. The injected rules will
   * be executed after the analysis phase. A check analysis rule is used to detect problems with a
   * LogicalPlan and should throw an exception when a problem is found.
   */
  def injectCheckRule(builder: CheckRuleBuilder): Unit = {
    checkRuleBuilders += builder
  }

  private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]

  private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
    optimizerRules.map(_.apply(session))
  }

  /**
   * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be
   * executed during the operator optimization batch. An optimizer rule is used to improve the
   * quality of an analyzed logical plan; these rules should never modify the result of the
   * LogicalPlan.
   */
  def injectOptimizerRule(builder: RuleBuilder): Unit = {
    optimizerRules += builder
  }

  private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]

  private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
    plannerStrategyBuilders.map(_.apply(session))
  }

  /**
   * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will
   * be used to convert a `LogicalPlan` into an executable
   * [[org.apache.spark.sql.execution.SparkPlan]].
   */
  def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
    plannerStrategyBuilders += builder
  }

  private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder]

  private[sql] def buildParser(
      session: SparkSession,
      initial: ParserInterface): ParserInterface = {
    parserBuilders.foldLeft(initial) { (parser, builder) =>
      builder(session, parser)
    }
  }

  /**
   * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session
   * and an initial parser. The latter allows for a user to create a partial parser and to delegate
   * to the underlying parser for completeness. If a user injects more parsers, then the parsers
   * are stacked on top of each other.
   */
  def injectParser(builder: ParserBuilder): Unit = {
    parserBuilders += builder
  }
}
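The other injection points all follow the same builder-of-X pattern: each inject* call registers a function from SparkSession to the object being contributed, and the matching build* method is invoked once the session exists. Since buildParser folds left over the registered builders, each parser builder receives the result of the previous one, so the last-injected parser ends up outermost. A combined sketch under the same assumptions as the earlier examples (MyOptimizerRule is hypothetical):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical optimizer rule; per the contract above it must not change query results.
case class MyOptimizerRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder()
  .master("local[2]")
  .withExtensions { extensions =>
    extensions.injectOptimizerRule(MyOptimizerRule)
    extensions.injectCheckRule { session => plan =>
      // Runs after analysis; throw here to reject an invalid plan.
      ()
    }
    extensions.injectParser { (session, parser) =>
      // parser is whatever was installed so far; a real extension would wrap it
      // (e.g. with the MyDelegatingParser sketch above). Returning it unchanged
      // keeps this sketch self-contained.
      parser
    }
  }
  .getOrCreate()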
