[SPARK-18127] Add hooks and extension points to Spark #17724

Closed (wants to merge 3 commits)
@@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
abstract class AbstractSqlParser extends ParserInterface with Logging {

/** Creates/Resolves DataType for a given SQL string. */
def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
// TODO add this to the parser interface.
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
}

@@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
}

/** Creates FunctionIdentifier for a given SQL string. */
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser =>
astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
parse(sqlText) { parser =>
astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
}
}

/**
@@ -17,30 +17,51 @@

package org.apache.spark.sql.catalyst.parser

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}

/**
* Interface for a parser.
*/
@DeveloperApi
trait ParserInterface {
/** Creates LogicalPlan for a given SQL string. */
/**
* Parse a string to a [[LogicalPlan]].
*/
@throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parsePlan(sqlText: String): LogicalPlan

/** Creates Expression for a given SQL string. */
/**
* Parse a string to an [[Expression]].
*/
@throws[ParseException]("Text cannot be parsed to an Expression")
def parseExpression(sqlText: String): Expression

/** Creates TableIdentifier for a given SQL string. */
/**
* Parse a string to a [[TableIdentifier]].
*/
@throws[ParseException]("Text cannot be parsed to a TableIdentifier")
def parseTableIdentifier(sqlText: String): TableIdentifier

/** Creates FunctionIdentifier for a given SQL string. */
/**
* Parse a string to a [[FunctionIdentifier]].
*/
@throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier

/**
* Creates StructType for a given SQL string, which is a comma separated list of field
* definitions which will preserve the correct Hive metadata.
* Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
* of field definitions which will preserve the correct Hive metadata.
*/
@throws[ParseException]("Text cannot be parsed to a schema")
def parseTableSchema(sqlText: String): StructType

/**
* Parse a string to a [[DataType]].
*/
@throws[ParseException]("Text cannot be parsed to a DataType")
def parseDataType(sqlText: String): DataType
}
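
With parseDataType now part of the interface, ParserInterface is the full contract a custom parser has to satisfy. A common pattern for the parser extension point added later in this PR is to delegate to an existing parser and only intercept what you need. The sketch below is illustrative only (the class name MyDelegatingParser and its constructor argument are not part of this PR); it simply forwards every method to an underlying parser, which is where a real extension would add its own handling in parsePlan before falling back to the delegate.

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}

// Illustrative sketch: forward every method to an underlying parser. A real custom
// parser would typically inspect sqlText in parsePlan and delegate everything else.
class MyDelegatingParser(delegate: ParserInterface) extends ParserInterface {
  override def parsePlan(sqlText: String): LogicalPlan = delegate.parsePlan(sqlText)
  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)
  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)
  override def parseTableSchema(sqlText: String): StructType =
    delegate.parseTableSchema(sqlText)
  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)
}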
@@ -81,4 +81,10 @@ object StaticSQLConf {
"SQL configuration and the current database.")
.booleanConf
.createWithDefault(false)

val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
.doc("Name of the class used to configure Spark Session extensions. The class should " +
"implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
.stringConf
.createOptional
}
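
As the documentation string above says, the class named by spark.sql.extensions is instantiated reflectively (hence the no-args constructor requirement) and then applied to the session's SparkSessionExtensions before the SparkSession is built. A minimal sketch of such a configurator, assuming an illustrative class name and a deliberately trivial no-op optimizer rule:

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Illustrative configurator, enabled with e.g.:
//   --conf spark.sql.extensions=com.example.MyExtensionsConfigurator
class MyExtensionsConfigurator extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Register a no-op optimizer rule; a real configurator would inject useful rules here.
    extensions.injectOptimizerRule { session =>
      new Rule[LogicalPlan] {
        override def apply(plan: LogicalPlan): LogicalPlan = plan
      }
    }
  }
}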
45 changes: 39 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.ui.SQLListener
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
@@ -77,11 +77,12 @@ import org.apache.spark.util.Utils
class SparkSession private(
@transient val sparkContext: SparkContext,
@transient private val existingSharedState: Option[SharedState],
@transient private val parentSessionState: Option[SessionState])
@transient private val parentSessionState: Option[SessionState],
@transient private[sql] val extensions: SparkSessionExtensions)
extends Serializable with Closeable with Logging { self =>

private[sql] def this(sc: SparkContext) {
this(sc, None, None)
this(sc, None, None, new SparkSessionExtensions)
}

sparkContext.assertNotStopped()
@@ -219,7 +220,7 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
}

/**
@@ -235,7 +236,7 @@
* implementation is Hive, this will initialize the metastore, which may take some time.
*/
private[sql] def cloneSession(): SparkSession = {
val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
result.sessionState // force copy of SessionState
result
}
@@ -754,6 +755,8 @@ object SparkSession {

private[this] val options = new scala.collection.mutable.HashMap[String, String]

private[this] val extensions = new SparkSessionExtensions

private[this] var userSuppliedContext: Option[SparkContext] = None

private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
@@ -847,6 +850,17 @@
}
}

/**
* Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules,
* Optimizer rules, Planning Strategies or a customized parser.
*
* @since 2.2.0
*/
def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
f(extensions)
this
}

/**
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
* one based on the options set in this builder.
@@ -903,7 +917,26 @@
}
sc
}
session = new SparkSession(sparkContext)

// Initialize extensions if the user has defined a configurator class.
val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
if (extensionConfOption.isDefined) {
val extensionConfClassName = extensionConfOption.get
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.newInstance()
.asInstanceOf[SparkSessionExtensions => Unit]
extensionConf(extensions)
} catch {
// Ignore the error if we cannot find the class or if the class has the wrong type.
case e @ (_: ClassCastException |
_: ClassNotFoundException |
_: NoClassDefFoundError) =>
logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
}
}

session = new SparkSession(sparkContext, None, None, extensions)
options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
defaultSession.set(session)

@@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import scala.collection.mutable

import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

/**
* :: Experimental ::
* Holder for injection points to the [[SparkSession]]. We make NO guarantee about the binary
* and source compatibility of the methods here.
*
* This currently provides the following extension points:
* - Analyzer Rules.
* - Check Analysis Rules.
* - Optimizer Rules.
* - Planning Strategies.
* - Customized Parser.
* - (External) Catalog listeners.
*
* The extensions can be used by calling withExtensions on the [[SparkSession.Builder]], for
* example:
* {{{
* SparkSession.builder()
* .master("...")
* .conf("...", true)
* .withExtensions { extensions =>
* extensions.injectResolutionRule { session =>
* ...
* }
* extensions.injectParser { (session, parser) =>
* ...
* }
* }
* .getOrCreate()
* }}}
*
* Note that the injected builders should not assume that the [[SparkSession]] is fully
* initialized and they should not touch the session's internals (e.g. the SessionState).
*/
@DeveloperApi
@Experimental
@InterfaceStability.Unstable
class SparkSessionExtensions {
type RuleBuilder = SparkSession => Rule[LogicalPlan]
type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
type StrategyBuilder = SparkSession => Strategy
type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface

private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
* Build the analyzer resolution `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
resolutionRuleBuilders.map(_.apply(session))
}

/**
* Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer
* rules will be executed as part of the resolution phase of analysis.
*/
def injectResolutionRule(builder: RuleBuilder): Unit = {
resolutionRuleBuilders += builder
}

private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
* Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
postHocResolutionRuleBuilders.map(_.apply(session))
}

/**
* Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer
* rules will be executed after resolution.
*/
def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
postHocResolutionRuleBuilders += builder
}

private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder]

/**
* Build the check analysis `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = {
checkRuleBuilders.map(_.apply(session))
}

/**
* Inject a check analysis `Rule` builder into the [[SparkSession]]. The injected rules will
* be executed after the analysis phase. A check analysis rule is used to detect problems with a
* LogicalPlan and should throw an exception when a problem is found.
*/
def injectCheckRule(builder: CheckRuleBuilder): Unit = {
checkRuleBuilders += builder
}

private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]

private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
optimizerRules.map(_.apply(session))
}

/**
* Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be
* executed during the operator optimization batch. An optimizer rule is used to improve the
* quality of an analyzed logical plan; these rules should never modify the result of the
* LogicalPlan.
*/
def injectOptimizerRule(builder: RuleBuilder): Unit = {
optimizerRules += builder
}

private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]

private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
plannerStrategyBuilders.map(_.apply(session))
}

/**
* Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will
* be used to convert a `LogicalPlan` into an executable
* [[org.apache.spark.sql.execution.SparkPlan]].
*/
def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
plannerStrategyBuilders += builder
}

private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder]

private[sql] def buildParser(
session: SparkSession,
initial: ParserInterface): ParserInterface = {
parserBuilders.foldLeft(initial) { (parser, builder) =>
builder(session, parser)
}
}

/**
* Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session
* and an initial parser. The latter allows for a user to create a partial parser and to delegate
* to the underlying parser for completeness. If a user injects more parsers, then the parsers
* are stacked on top of each other.
*/
def injectParser(builder: ParserBuilder): Unit = {
parserBuilders += builder
}
}
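
For completeness, here is a minimal end-to-end sketch of the builder-based route described above. All names (ExtensionsExample, NoopRule, NoopStrategy, the master and app name) are illustrative, and the injected pieces are deliberately no-ops so they only demonstrate the builder shapes this class expects: SparkSession => Rule[LogicalPlan] for the analyzer and optimizer hooks, SparkSession => LogicalPlan => Unit for check rules, SparkSession => Strategy for planner strategies, and (SparkSession, ParserInterface) => ParserInterface for parsers.

import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

object ExtensionsExample {
  // No-op analyzer/optimizer rule; its companion object is a SparkSession => Rule[LogicalPlan].
  case class NoopRule(spark: SparkSession) extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan
  }

  // A strategy that plans nothing, so Spark falls back to its built-in strategies.
  case class NoopStrategy(spark: SparkSession) extends Strategy {
    override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("extensions-example")
      .withExtensions { extensions =>
        extensions.injectResolutionRule(NoopRule)
        extensions.injectPostHocResolutionRule(NoopRule)
        extensions.injectCheckRule(session => plan => ()) // LogicalPlan => Unit; throw to fail analysis
        extensions.injectOptimizerRule(NoopRule)
        extensions.injectPlannerStrategy(NoopStrategy)
        extensions.injectParser((session, parser) => parser) // identity; wrap `parser` to customize
      }
      .getOrCreate()

    spark.range(10).selectExpr("id * 2 AS doubled").show()
    spark.stop()
  }
}

Passing the case-class companions (NoopRule, NoopStrategy) works because a case class companion is itself a Function1 from the constructor argument to the class, which is exactly the RuleBuilder and StrategyBuilder shape defined above.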