Skip to content

Commit 83f259f

Browse files
allisonwang-dbcloud-fan
authored andcommitted
[SPARK-33183][SQL][3.0] Fix Optimizer rule EliminateSorts and add a physical rule to remove redundant sorts
Backport #30093 for branch-3.0. I've updated the configuration version to 2.4.8. ### What changes were proposed in this pull request? This PR aims to fix a correctness bug in the optimizer rule EliminateSorts. It also adds a new physical rule to remove redundant sorts that cannot be eliminated in the Optimizer rule after the bugfix. ### Why are the changes needed? A global sort should not be eliminated even if its child is ordered since we don't know if its child ordering is global or local. For example, in the following scenario, the first sort shouldn't be removed because it has a stronger guarantee than the second sort even if the sort orders are the same for both sorts. ``` Sort(orders, global = True, ...) Sort(orders, global = False, ...) ``` Since there is no straightforward way to identify whether a node's output ordering is local or global, we should not remove a global sort even if its child is already ordered. ### Does this PR introduce any user-facing change? Yes ### How was this patch tested? Unit tests Closes #30195 from allisonwang-db/SPARK-33183-branch-3.0. Authored-by: allisonwang-db <66282705+allisonwang-db@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 8f57603 commit 83f259f

File tree

8 files changed

+267
-28
lines changed

8 files changed

+267
-28
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -971,20 +971,26 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
971971
/**
972972
* Removes Sort operation. This can happen:
973973
* 1) if the sort order is empty or the sort order does not have any reference
974-
* 2) if the child is already sorted
974+
* 2) if the Sort operator is a local sort and the child is already sorted
975975
* 3) if there is another Sort operator separated by 0...n Project/Filter operators
976976
* 4) if the Sort operator is within Join separated by 0...n Project/Filter operators only,
977977
* and the Join conditions is deterministic
978978
* 5) if the Sort operator is within GroupBy separated by 0...n Project/Filter operators only,
979979
* and the aggregate function is order irrelevant
980980
*/
981981
object EliminateSorts extends Rule[LogicalPlan] {
982-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
982+
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
983+
984+
private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
983985
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
984986
val newOrders = orders.filterNot(_.child.foldable)
985-
if (newOrders.isEmpty) child else s.copy(order = newOrders)
986-
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
987-
child
987+
if (newOrders.isEmpty) {
988+
applyLocally.lift(child).getOrElse(child)
989+
} else {
990+
s.copy(order = newOrders)
991+
}
992+
case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
993+
applyLocally.lift(child).getOrElse(child)
988994
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
989995
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
990996
j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight))

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,13 @@ object SQLConf {
12011201
.booleanConf
12021202
.createWithDefault(true)
12031203

1204+
val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts")
1205+
.internal()
1206+
.doc("Whether to remove redundant physical sort node")
1207+
.version("2.4.8")
1208+
.booleanConf
1209+
.createWithDefault(true)
1210+
12041211
val STATE_STORE_PROVIDER_CLASS =
12051212
buildConf("spark.sql.streaming.stateStore.providerClass")
12061213
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 92 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,34 @@ class EliminateSortsSuite extends PlanTest {
9797
comparePlans(optimized, correctAnswer)
9898
}
9999

100-
test("remove redundant order by") {
100+
test("SPARK-33183: remove consecutive no-op sorts") {
101+
val plan = testRelation.orderBy().orderBy().orderBy()
102+
val optimized = Optimize.execute(plan.analyze)
103+
val correctAnswer = testRelation.analyze
104+
comparePlans(optimized, correctAnswer)
105+
}
106+
107+
test("SPARK-33183: remove redundant sort by") {
101108
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
102-
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
109+
val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst)
103110
val optimized = Optimize.execute(unnecessaryReordered.analyze)
104111
val correctAnswer = orderedPlan.limit(2).select('a).analyze
105-
comparePlans(Optimize.execute(optimized), correctAnswer)
112+
comparePlans(optimized, correctAnswer)
113+
}
114+
115+
test("SPARK-33183: remove all redundant local sorts") {
116+
val orderedPlan = testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc)
117+
val optimized = Optimize.execute(orderedPlan.analyze)
118+
val correctAnswer = testRelation.orderBy('a.asc).analyze
119+
comparePlans(optimized, correctAnswer)
120+
}
121+
122+
test("SPARK-33183: should not remove global sort") {
123+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
124+
val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
125+
val optimized = Optimize.execute(reordered.analyze)
126+
val correctAnswer = reordered.analyze
127+
comparePlans(optimized, correctAnswer)
106128
}
107129

108130
test("do not remove sort if the order is different") {
@@ -113,22 +135,39 @@ class EliminateSortsSuite extends PlanTest {
113135
comparePlans(optimized, correctAnswer)
114136
}
115137

116-
test("filters don't affect order") {
138+
test("SPARK-33183: remove top level local sort with filter operators") {
117139
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
118-
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
140+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
119141
val optimized = Optimize.execute(filteredAndReordered.analyze)
120142
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
121143
comparePlans(optimized, correctAnswer)
122144
}
123145

124-
test("limits don't affect order") {
146+
test("SPARK-33183: keep top level global sort with filter operators") {
147+
val projectPlan = testRelation.select('a, 'b)
148+
val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc)
149+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
150+
val optimized = Optimize.execute(filteredAndReordered.analyze)
151+
val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze
152+
comparePlans(optimized, correctAnswer)
153+
}
154+
155+
test("SPARK-33183: limits should not affect order for local sort") {
125156
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
126-
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
157+
val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc)
127158
val optimized = Optimize.execute(filteredAndReordered.analyze)
128159
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
129160
comparePlans(optimized, correctAnswer)
130161
}
131162

163+
test("SPARK-33183: should not remove global sort with limit operators") {
164+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
165+
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
166+
val optimized = Optimize.execute(filteredAndReordered.analyze)
167+
val correctAnswer = filteredAndReordered.analyze
168+
comparePlans(optimized, correctAnswer)
169+
}
170+
132171
test("different sorts are not simplified if limit is in between") {
133172
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
134173
.orderBy('a.asc)
@@ -137,11 +176,11 @@ class EliminateSortsSuite extends PlanTest {
137176
comparePlans(optimized, correctAnswer)
138177
}
139178

140-
test("range is already sorted") {
179+
test("SPARK-33183: should not remove global sort with range operator") {
141180
val inputPlan = Range(1L, 1000L, 1, 10)
142181
val orderedPlan = inputPlan.orderBy('id.asc)
143182
val optimized = Optimize.execute(orderedPlan.analyze)
144-
val correctAnswer = inputPlan.analyze
183+
val correctAnswer = orderedPlan.analyze
145184
comparePlans(optimized, correctAnswer)
146185

147186
val reversedPlan = inputPlan.orderBy('id.desc)
@@ -152,10 +191,18 @@ class EliminateSortsSuite extends PlanTest {
152191
val negativeStepInputPlan = Range(10L, 1L, -1, 10)
153192
val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
154193
val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
155-
val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
194+
val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze
156195
comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
157196
}
158197

198+
test("SPARK-33183: remove local sort with range operator") {
199+
val inputPlan = Range(1L, 1000L, 1, 10)
200+
val orderedPlan = inputPlan.sortBy('id.asc)
201+
val optimized = Optimize.execute(orderedPlan.analyze)
202+
val correctAnswer = inputPlan.analyze
203+
comparePlans(optimized, correctAnswer)
204+
}
205+
159206
test("sort should not be removed when there is a node which doesn't guarantee any order") {
160207
val orderedPlan = testRelation.select('a, 'b)
161208
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
@@ -319,4 +366,39 @@ class EliminateSortsSuite extends PlanTest {
319366
val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze)
320367
comparePlans(optimized, correctAnswer)
321368
}
369+
370+
test("SPARK-33183: remove consecutive global sorts with the same ordering") {
371+
Seq(
372+
(testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)),
373+
(testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc))
374+
).foreach { case (ordered, answer) =>
375+
val optimized = Optimize.execute(ordered.analyze)
376+
comparePlans(optimized, answer.analyze)
377+
}
378+
}
379+
380+
test("SPARK-33183: remove consecutive local sorts with the same ordering") {
381+
val orderedPlan = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc)
382+
val optimized = Optimize.execute(orderedPlan.analyze)
383+
val correctAnswer = testRelation.sortBy('a.asc).analyze
384+
comparePlans(optimized, correctAnswer)
385+
}
386+
387+
test("SPARK-33183: remove consecutive local sorts with different ordering") {
388+
val orderedPlan = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc)
389+
val optimized = Optimize.execute(orderedPlan.analyze)
390+
val correctAnswer = testRelation.sortBy('a.asc).analyze
391+
comparePlans(optimized, correctAnswer)
392+
}
393+
394+
test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") {
395+
val correctAnswer = testRelation.orderBy('a.asc).analyze
396+
Seq(
397+
testRelation.sortBy('a.asc).orderBy('a.asc),
398+
testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc)
399+
).foreach { ordered =>
400+
val optimized = Optimize.execute(ordered.analyze)
401+
comparePlans(optimized, correctAnswer)
402+
}
403+
}
322404
}

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ object QueryExecution {
297297
Seq(
298298
PlanDynamicPruningFilters(sparkSession),
299299
PlanSubqueries(sparkSession),
300+
RemoveRedundantSorts(sparkSession.sessionState.conf),
300301
EnsureRequirements(sparkSession.sessionState.conf),
301302
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
302303
sparkSession.sessionState.columnarRules),
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.expressions.SortOrder
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.internal.SQLConf
23+
24+
/**
25+
* Remove redundant SortExec node from the spark plan. A sort node is redundant when
26+
* its child satisfies both its sort orders and its required child distribution. Note
27+
* this rule differs from the Optimizer rule EliminateSorts in that this rule also checks
28+
* if the child satisfies the required distribution so that it is safe to remove not only a
29+
* local sort but also a global sort when its child already satisfies required sort orders.
30+
*/
31+
case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] {
32+
def apply(plan: SparkPlan): SparkPlan = {
33+
if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) {
34+
plan
35+
} else {
36+
removeSorts(plan)
37+
}
38+
}
39+
40+
private def removeSorts(plan: SparkPlan): SparkPlan = plan transform {
41+
case s @ SortExec(orders, _, child, _)
42+
if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
43+
child.outputPartitioning.satisfies(s.requiredChildDistribution.head) =>
44+
child
45+
}
46+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,14 @@ case class AdaptiveSparkPlanExec(
8383
)
8484
}
8585

86+
@transient private val removeRedundantSorts = RemoveRedundantSorts(conf)
8687
@transient private val ensureRequirements = EnsureRequirements(conf)
8788

8889
// A list of physical plan rules to be applied before creation of query stages. The physical
8990
// plan should reach a final status of query stages (i.e., no more addition or removal of
9091
// Exchange nodes) after running these rules.
9192
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
93+
removeRedundantSorts,
9294
ensureRequirements
9395
) ++ context.session.sessionState.queryStagePrepRules
9496

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,19 +234,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
234234
}
235235
}
236236

237-
test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
238-
val query = testData.select('key, 'value).sort('key.desc).cache()
239-
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
240-
val resorted = query.sort('key.desc)
241-
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
242-
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
243-
(1 to 100).reverse)
244-
// with a different order, the sort is needed
245-
val sortedAsc = query.sort('key)
246-
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1)
247-
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
248-
}
249-
250237
test("PartitioningCollection") {
251238
withTempView("normal", "small", "tiny") {
252239
testData.createOrReplaceTempView("normal")

0 commit comments

Comments
 (0)