
[SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer #20560


Closed · wants to merge 12 commits
@@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
operatorOptimizationBatch) :+
Batch("Join Reorder", Once,
CostBasedJoinReorder) :+
Batch("Remove Redundant Sorts", Once,
RemoveRedundantSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
@@ -733,6 +735,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
}
}

/**
* Removes a Sort operation if the child is already sorted, i.e. the child's
* output ordering already satisfies the requested sort orders
*/
object RemoveRedundantSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
child
}
}

/**
* Removes filters that can be evaluated trivially. This can be done through the following ways:
* 1) by eliding the filter for cases where it will always evaluate to `true`.
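To illustrate the new rule above: SortOrder.orderingSatisfies(child.outputOrdering, orders) holds when the child's known output ordering is at least as specific as the requested one. A minimal sketch using the Catalyst DSL (the relation and column names are illustrative, mirroring the test suite added later in this PR):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantSorts
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int)
// The outer sort requests an ordering the inner sort already provides,
// so RemoveRedundantSorts rewrites the plan down to a single sort.
val plan = relation.orderBy('a.asc).orderBy('a.asc).analyze
val optimized = RemoveRedundantSorts(plan)
// optimized is equivalent to relation.orderBy('a.asc).analyze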
@@ -219,6 +219,11 @@ abstract class LogicalPlan
* Refreshes (or invalidates) any metadata/data cached in the plan recursively.
*/
def refresh(): Unit = children.foreach(_.refresh())

/**
* Returns the output ordering that this plan generates.
*/
def outputOrdering: Seq[SortOrder] = Nil
}

/**
@@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan {

override final def children: Seq[LogicalPlan] = Seq(left, right)
}

abstract class OrderPreservingUnaryNode extends UnaryNode {
override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
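Design note: OrderPreservingUnaryNode is opt-in. Only unary nodes that provably keep their child's row order (Project, Filter, the limit nodes and the subquery wrappers changed below) extend it, and the override is final so subclasses cannot weaken the guarantee; every other operator inherits the conservative default of Nil, i.e. "no known ordering".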
@@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
* This node is inserted at the top of a subquery when it is optimized. This makes sure we can
* recognize a subquery as such, and it allows us to write subquery aware transformations.
*/
case class Subquery(child: LogicalPlan) extends UnaryNode {
case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
}

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
Review comment (Contributor):
Like ProjectExec.outputOrdering, we can propagate ordering for aliased attributes.

Reply (Contributor, author):
sorry, I don't fully understand what you mean. In ProjectExec.outputOrdering we are getting the child.outputOrdering exactly as it is done here.
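For readers of this thread: the reviewer appears to be suggesting that a Project which aliases a sort column could still report an ordering, rewritten in terms of its own output attributes. A rough sketch of that idea, assuming Catalyst's Alias and SortOrder APIs; the helper name is hypothetical and this is not what the PR implements (the PR simply forwards child.outputOrdering):

import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression, SortOrder}

// Hypothetical sketch, not part of this PR: if the child is sorted by `a`
// and the project renames `a` to `x`, the project's output is still sorted,
// now by `x`. A complete version would also drop orderings on expressions
// that no longer appear in the output at all.
def aliasedOutputOrdering(
    projectList: Seq[NamedExpression],
    childOrdering: Seq[SortOrder]): Seq[SortOrder] = {
  // Map each aliased child expression to the attribute it is exposed as.
  val aliasMap = projectList.collect {
    case a @ Alias(child, _) => child.canonicalized -> a.toAttribute
  }.toMap
  childOrdering.map { so =>
    so.copy(child = aliasMap.getOrElse(so.child.canonicalized, so.child))
  }
}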

extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows

@@ -125,7 +126,7 @@ case class Generate(
}

case class Filter(condition: Expression, child: LogicalPlan)
extends UnaryNode with PredicateHelper {
extends OrderPreservingUnaryNode with PredicateHelper {
override def output: Seq[Attribute] = child.output

override def maxRows: Option[Long] = child.maxRows
@@ -469,6 +470,7 @@ case class Sort(
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def outputOrdering: Seq[SortOrder] = order
}

/** Factory for constructing new `Range` nodes. */
@@ -522,6 +524,15 @@ case class Range(
override def computeStats(): Statistics = {
Statistics(sizeInBytes = LongType.defaultSize * numElements)
}

override def outputOrdering: Seq[SortOrder] = {
val order = if (step > 0) {
Ascending
} else {
Descending
}
output.map(a => SortOrder(a, order))
}
}

case class Aggregate(
@@ -728,7 +739,7 @@ object Limit {
*
* See [[Limit]] for more information.
*/
case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = {
limitExpr match {
@@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
*
* See [[Limit]] for more information.
*/
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output

override def maxRowsPerPartition: Option[Long] = {
@@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
case class SubqueryAlias(
alias: String,
child: LogicalPlan)
extends UnaryNode {
extends OrderPreservingUnaryNode {

override def doCanonicalize(): LogicalPlan = child.canonicalized

@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}

class RemoveRedundantSortsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Remove Redundant Sorts", Once,
RemoveRedundantSorts) ::
Batch("Collapse Project", Once,
CollapseProject) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

test("remove redundant order by") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val optimized = Optimize.execute(unnecessaryReordered.analyze)
val correctAnswer = orderedPlan.select('a).analyze
comparePlans(Optimize.execute(optimized), correctAnswer)
}

test("do not remove sort if the order is different") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(reorderedDifferently.analyze)
val correctAnswer = reorderedDifferently.analyze
comparePlans(optimized, correctAnswer)
}

test("filters don't affect order") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("limits don't affect order") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("range is already sorted") {
val inputPlan = Range(1L, 1000L, 1, 10)
val orderedPlan = inputPlan.orderBy('id.asc)
val optimized = Optimize.execute(orderedPlan.analyze)
val correctAnswer = inputPlan.analyze
comparePlans(optimized, correctAnswer)

val reversedPlan = inputPlan.orderBy('id.desc)
val reversedOptimized = Optimize.execute(reversedPlan.analyze)
val reversedCorrectAnswer = reversedPlan.analyze
comparePlans(reversedOptimized, reversedCorrectAnswer)

val negativeStepInputPlan = Range(10L, 1L, -1, 10)
val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
}

test("sort should not be removed when there is a node which doesn't guarantee any order") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc)
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
val optimized = Optimize.execute(groupedAndResorted.analyze)
val correctAnswer = groupedAndResorted.analyze
comparePlans(optimized, correctAnswer)
}
}
@@ -99,7 +99,7 @@ class CacheManager extends Logging {
sparkSession.sessionState.conf.columnBatchSize, storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName,
planToCache.stats)
planToCache)
cachedData.add(CachedData(planToCache, inMemoryRelation))
}
}
@@ -148,7 +148,7 @@ class CacheManager extends Logging {
storageLevel = cd.cachedRepresentation.storageLevel,
child = spark.sessionState.executePlan(cd.plan).executedPlan,
tableName = cd.cachedRepresentation.tableName,
statsOfPlanToCache = cd.plan.stats)
logicalPlan = cd.plan)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}
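Note that both call sites now hand the cache's InMemoryRelation the full logical plan rather than just its statistics: besides stats, InMemoryRelation can then also capture the plan's outputOrdering (see the InMemoryRelation changes below), so a cached, sorted plan still reports its ordering to the optimizer.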
@@ -125,7 +125,7 @@ case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow],
outputPartitioning: Partitioning = UnknownPartitioning(0),
outputOrdering: Seq[SortOrder] = Nil,
override val outputOrdering: Seq[SortOrder] = Nil,
override val isStreaming: Boolean = false)(session: SparkSession)
extends LeafNode with MultiInstanceRelation {

@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator
@@ -39,9 +39,9 @@ object InMemoryRelation {
storageLevel: StorageLevel,
child: SparkPlan,
tableName: Option[String],
statsOfPlanToCache: Statistics): InMemoryRelation =
logicalPlan: LogicalPlan): InMemoryRelation =
new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
statsOfPlanToCache = statsOfPlanToCache)
statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
}


@@ -64,7 +64,8 @@ case class InMemoryRelation(
tableName: Option[String])(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
statsOfPlanToCache: Statistics)
statsOfPlanToCache: Statistics,
override val outputOrdering: Seq[SortOrder])
Review comment (Member):
This should be added to otherCopyArgs; otherwise, we will lose it when doing the tree transformation. #22715 fixed it.
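A sketch of the suggested fix, assuming Catalyst's TreeNode.otherCopyArgs hook (the actual change in #22715 may differ). Arguments in the second, curried parameter list are not part of the case class's product, so makeCopy drops them during tree transformations unless they are listed explicitly:

// Inside InMemoryRelation: preserve the curried constructor arguments across makeCopy.
override protected def otherCopyArgs: Seq[AnyRef] =
  Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering)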

extends logical.LeafNode with MultiInstanceRelation {

override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -76,7 +77,8 @@ case class InMemoryRelation(
tableName = None)(
_cachedColumnBuffers,
sizeInBytesStats,
statsOfPlanToCache)
statsOfPlanToCache,
outputOrdering)

override def producedAttributes: AttributeSet = outputSet

@@ -159,7 +161,7 @@ case class InMemoryRelation(
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering)
}

override def newInstance(): this.type = {
@@ -172,7 +174,8 @@ case class InMemoryRelation(
tableName)(
_cachedColumnBuffers,
sizeInBytesStats,
statsOfPlanToCache).asInstanceOf[this.type]
statsOfPlanToCache,
outputOrdering).asInstanceOf[this.type]
}

def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
@@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {
def computeChiSquareTest(): Double = {
val n = 10000
// Trigger a sort. Sort descending: an ascending sort on an already-ascending
// Range would now be eliminated by RemoveRedundantSorts.
val data = spark.range(0, n, 1, 1).sort('id)
val data = spark.range(0, n, 1, 1).sort('id.desc)
.selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect()

// Compute histogram for the number of records per partition post sort
@@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
@@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext {
assert(planned.child.isInstanceOf[CollectLimitExec])
}

test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
val query = testData.select('key, 'value).sort('key.desc).cache()
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
val resorted = query.sort('key.desc)
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
(1 to 100).reverse)
// with a different order, the sort is needed
val sortedAsc = query.sort('key)
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1)
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
}

test("PartitioningCollection") {
withTempView("normal", "small", "tiny") {
testData.createOrReplaceTempView("normal")