Skip to content

[SPARK-22080][SQL] Adds support for allowing user to add pre-optimization rules #19295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ class ExperimentalMethods private[sql]() {
*/
@volatile var extraStrategies: Seq[Strategy] = Nil

@volatile var extraPreOptimizations: Seq[Rule[LogicalPlan]] = Nil

@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about renaming this to `extraPostOptimizations`?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an API change. We can't do it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree with @gatorsmile: renaming `extraOptimizations` to `extraPostOptimizations` would be symmetric with `extraPreOptimizations`, but doing so may break existing API callers.


/**
 * Creates a copy of this [[ExperimentalMethods]] carrying over all user-registered
 * strategies and optimizer rules, so a cloned session starts with the same extensions.
 */
override def clone(): ExperimentalMethods = {
  val copied = new ExperimentalMethods
  copied.extraStrategies = extraStrategies
  copied.extraPreOptimizations = extraPreOptimizations
  copied.extraOptimizations = extraOptimizations
  copied
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@ class SparkOptimizer(
experimentalMethods: ExperimentalMethods)
extends Optimizer(catalog) {

// Batch of user-supplied rules that run BEFORE all built-in optimizer batches.
// Uses fixedPoint so the rules are re-applied until the plan stabilizes.
val experimentalPreOptimizations: Batch = Batch(
  "User Provided Pre Optimizers", fixedPoint, experimentalMethods.extraPreOptimizations: _*)

// Batch of user-supplied rules that run AFTER all built-in optimizer batches.
val experimentalPostOptimizations: Batch = Batch(
  "User Provided Post Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)

// Ordering: user pre-optimizations, internal pre-optimization batches, the standard
// Catalyst batches, Spark-specific batches, post-hoc batches, then user post-optimizations.
override def batches: Seq[Batch] =
  ((experimentalPreOptimizations +: preOptimizationBatches) ++ super.batches :+
  Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
  Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
  Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
  postHocOptimizationBatches :+ experimentalPostOptimizations

/**
* Optimization batches that are executed before the regular optimization batches (also before
Expand Down
16 changes: 13 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructT
@deprecated("This suite is deprecated to silent compiler deprecation warnings", "2.0.0")
class SQLContextSuite extends SparkFunSuite with SharedSparkContext {

// No-op rule registered as a user-provided POST-optimization; identity on the plan.
object DummyPostOptimizationRule extends Rule[LogicalPlan] {
  def apply(p: LogicalPlan): LogicalPlan = p
}

// No-op rule registered as a user-provided PRE-optimization; identity on the plan.
object DummyPreOptimizationRule extends Rule[LogicalPlan] {
  def apply(p: LogicalPlan): LogicalPlan = p
}

Expand Down Expand Up @@ -78,8 +82,14 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {

test("Catalyst optimization passes are modifiable at runtime") {
  val sqlContext = SQLContext.getOrCreate(sc)
  // Register one user rule at each extension point.
  sqlContext.experimental.extraOptimizations = Seq(DummyPostOptimizationRule)
  sqlContext.experimental.extraPreOptimizations = Seq(DummyPreOptimizationRule)

  val firstBatch = sqlContext.sessionState.optimizer.batches.head
  val lastBatch = sqlContext.sessionState.optimizer.batches.last

  // Pre-optimizations must form the very first batch; post-optimizations the very last.
  assert(firstBatch.rules == Seq(DummyPreOptimizationRule))
  assert(lastBatch.rules == Seq(DummyPostOptimizationRule))
}

test("get all tables") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class SessionStateSuite extends SparkFunSuite
}

test("fork new session and inherit experimental methods") {
  // Save the current extensions so the shared session is restored even on failure.
  val originalExtraPostOptimizations = activeSession.experimental.extraOptimizations
  val originalExtraPreOptimizations = activeSession.experimental.extraPreOptimizations
  val originalExtraStrategies = activeSession.experimental.extraStrategies
  try {
    object DummyRule1 extends Rule[LogicalPlan] {
      def apply(p: LogicalPlan): LogicalPlan = p
    }
    object DummyRule2 extends Rule[LogicalPlan] {
      def apply(p: LogicalPlan): LogicalPlan = p
    }
    object DummyRule3 extends Rule[LogicalPlan] {
      def apply(p: LogicalPlan): LogicalPlan = p
    }
    val preOptimizations = List(DummyRule3)
    val postOptimizations = List(DummyRule1, DummyRule2)
    activeSession.experimental.extraPreOptimizations = preOptimizations
    activeSession.experimental.extraOptimizations = postOptimizations
    val forkedSession = activeSession.cloneSession()

    // inheritance: the fork is a distinct object but carries the same rule sets.
    assert(forkedSession ne activeSession)
    assert(forkedSession.experimental ne activeSession.experimental)
    assert(forkedSession.experimental.extraPreOptimizations.toSet ==
      activeSession.experimental.extraPreOptimizations.toSet)
    assert(forkedSession.experimental.extraOptimizations.toSet ==
      activeSession.experimental.extraOptimizations.toSet)

    // independence: mutating one session must not affect the other.
    forkedSession.experimental.extraPreOptimizations = List(DummyRule1)
    forkedSession.experimental.extraOptimizations = List(DummyRule2)
    assert(activeSession.experimental.extraPreOptimizations == preOptimizations)
    assert(activeSession.experimental.extraOptimizations == postOptimizations)
    activeSession.experimental.extraPreOptimizations = List(DummyRule3)
    activeSession.experimental.extraOptimizations = List(DummyRule1)
    assert(forkedSession.experimental.extraPreOptimizations == List(DummyRule1))
    assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
  } finally {
    activeSession.experimental.extraPreOptimizations = originalExtraPreOptimizations
    activeSession.experimental.extraOptimizations = originalExtraPostOptimizations
    activeSession.experimental.extraStrategies = originalExtraStrategies
  }
}
Expand Down