
Commit bcbd9fa

add physical rule
1 parent 2966802 commit bcbd9fa

7 files changed (+193, -2 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 0 deletions
@@ -1052,6 +1052,8 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
  * function is order irrelevant
  */
 object EliminateSorts extends Rule[LogicalPlan] {
+  // transformUp is needed here to ensure idempotency of this rule when removing consecutive
+  // local sorts.
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
       val newOrders = orders.filterNot(_.child.foldable)
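Why transformUp matters here: with a top-down pass, removing one sort can expose a newly redundant pair underneath that the same pass never revisits, so the rule would need a second run to reach a fixed point. The toy model below (not Spark's actual TreeNode API; the Scan/LocalSort nodes are hypothetical) sketches how a bottom-up pass collapses a whole chain of identical local sorts in a single application:

sealed trait Plan
case class Scan(name: String) extends Plan
case class LocalSort(keys: Seq[String], child: Plan) extends Plan

// Bottom-up: rewrite the children first, then try the rule on the parent,
// so the parent always sees an already-simplified child.
def transformUp(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
  val withNewChildren = plan match {
    case LocalSort(keys, child) => LocalSort(keys, transformUp(child)(rule))
    case leaf => leaf
  }
  rule.applyOrElse(withNewChildren, identity[Plan])
}

// The rule: a local sort over an identically sorted child is redundant.
val collapseSorts: PartialFunction[Plan, Plan] = {
  case LocalSort(keys, LocalSort(childKeys, grandChild)) if keys == childKeys =>
    LocalSort(keys, grandChild)
}

val nested = LocalSort(Seq("a"), LocalSort(Seq("a"), LocalSort(Seq("a"), Scan("t"))))
// One bottom-up pass already yields LocalSort(List(a), Scan(t)); applying the
// rule again changes nothing, i.e. the rewrite is idempotent.
println(transformUp(nested)(collapseSorts))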

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 7 additions & 0 deletions
@@ -1242,6 +1242,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts")
+    .internal()
+    .doc("Whether to remove redundant physical sort node")
+    .version("3.1.0")
+    .booleanConf
+    .createWithDefault(true)
+
   val STATE_STORE_PROVIDER_CLASS =
     buildConf("spark.sql.streaming.stateStore.providerClass")
       .internal()
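The flag is registered as internal, but it can still be toggled per session, which is handy for comparing plans with and without the rule. A small usage sketch, assuming a running SparkSession named spark and an existing table t:

// Disable the physical rule, inspect the plan, then re-enable it.
spark.conf.set("spark.sql.execution.removeRedundantSorts", "false")
spark.sql("SELECT * FROM t ORDER BY key").explain()
spark.conf.set("spark.sql.execution.removeRedundantSorts", "true")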

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -128,15 +128,15 @@ class EliminateSortsSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("SPARK-33183: filters should not affect order for local sort") {
+  test("SPARK-33183: remove top level local sort with filter operators") {
     val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
     val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
     val optimized = Optimize.execute(filteredAndReordered.analyze)
     val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
     comparePlans(optimized, correctAnswer)
   }
 
-  test("SPARK-33183: should not remove global sort with filter operators") {
+  test("SPARK-33183: keep top level global sort with filter operators") {
     val projectPlan = testRelation.select('a, 'b)
     val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc)
     val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
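The renames track the semantics more precisely: sortBy produces a local (per-partition) sort, orderBy a global one. A filter over an already globally sorted child still satisfies a repeated local sort, so the logical rule may drop it; a repeated global sort is kept at the logical level, since whether the child's physical partitioning satisfies the required ordered distribution is only known at planning time, which is what the new physical rule checks. A DataFrame-level sketch of the same two shapes, assuming a SparkSession named spark:

import org.apache.spark.sql.functions.col

val base = spark.range(100).select(col("id").as("a"))
val globallySorted = base.orderBy(col("a").asc)

// Top-level local re-sort after a filter: removable by the logical rule.
val localResort = globallySorted.where(col("a") > 10).sortWithinPartitions(col("a").asc)

// Top-level global re-sort: kept logically; only the physical
// RemoveRedundantSorts rule may drop it, once the child's actual output
// partitioning is known.
val globalResort = globallySorted.where(col("a") > 10).orderBy(col("a").asc)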

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 0 deletions
@@ -343,6 +343,7 @@ object QueryExecution {
       PlanDynamicPruningFilters(sparkSession),
       PlanSubqueries(sparkSession),
       RemoveRedundantProjects(sparkSession.sessionState.conf),
+      RemoveRedundantSorts(sparkSession.sessionState.conf),
       EnsureRequirements(sparkSession.sessionState.conf),
       DisableUnnecessaryBucketedScan(sparkSession.sessionState.conf),
       ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
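These preparation rules run sequentially, each rewriting the physical plan the previous rule produced; placing RemoveRedundantSorts ahead of EnsureRequirements means a dropped sort can still be re-inserted afterwards if some operator genuinely requires that ordering. A minimal sketch of that fold, assuming only the Rule[SparkPlan] interface (it simplifies, rather than reproduces, QueryExecution's internals):

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Each rule receives the previous rule's output, in list order.
def applyPreparations(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan =
  rules.foldLeft(plan) { (current, rule) => rule.apply(current) }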
sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Remove redundant SortExec node from the spark plan. A sort node is redundant when
+ * its child satisfies both its sort orders and its required child distribution.
+ */
+case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] {
+  def apply(plan: SparkPlan): SparkPlan = {
+    if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) {
+      plan
+    } else {
+      removeSorts(plan)
+    }
+  }
+
+  private def removeSorts(plan: SparkPlan): SparkPlan = plan transform {
+    case s @ SortExec(orders, _, child, _)
+        if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
+          child.outputPartitioning.satisfies(s.requiredChildDistribution.head) =>
+      child
+  }
+}
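Both halves of the guard must hold: SortOrder.orderingSatisfies checks that the child's reported output ordering subsumes the sort's required orders, and the partitioning check matters mainly for global sorts, whose required child distribution is an ordered distribution. A REPL-style sketch of the ordering half, using the same catalyst helper the rule calls (catalyst internals, so this is for exploration rather than application code):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.SortOrder

val key = 'key.long                 // an AttributeReference via the test DSL
val childOrdering = Seq(key.asc)    // e.g. what a sorted child reports

// The same order is satisfied; the reversed direction is not.
assert(SortOrder.orderingSatisfies(childOrdering, Seq(key.asc)))
assert(!SortOrder.orderingSatisfies(childOrdering, Seq(key.desc)))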

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 2 additions & 0 deletions
@@ -83,13 +83,15 @@ case class AdaptiveSparkPlanExec(
   @transient private val optimizer = new AQEOptimizer(conf)
 
   @transient private val removeRedundantProjects = RemoveRedundantProjects(conf)
+  @transient private val removeRedundantSorts = RemoveRedundantSorts(conf)
   @transient private val ensureRequirements = EnsureRequirements(conf)
 
   // A list of physical plan rules to be applied before creation of query stages. The physical
   // plan should reach a final status of query stages (i.e., no more addition or removal of
   // Exchange nodes) after running these rules.
   private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
     removeRedundantProjects,
+    removeRedundantSorts,
     ensureRequirements
   ) ++ context.session.sessionState.queryStagePrepRules

sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+
+abstract class RemoveRedundantSortsSuiteBase
+    extends QueryTest
+    with SharedSparkSession
+    with AdaptiveSparkPlanHelper {
+  import testImplicits._
+
+  private def checkNumSorts(df: DataFrame, count: Int): Unit = {
+    val plan = df.queryExecution.executedPlan
+    assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count)
+  }
+
+  private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
+    withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
+      val df = sql(query)
+      checkNumSorts(df, enabledCount)
+      val result = df.collect()
+      withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
+        val df = sql(query)
+        checkNumSorts(df, disabledCount)
+        checkAnswer(df, result)
+      }
+    }
+  }
+
+  test("remove redundant sorts with limit") {
+    withTempView("t") {
+      spark.range(100).select('id as "key").createOrReplaceTempView("t")
+      val query =
+        """
+          |SELECT key FROM
+          | (SELECT key FROM t WHERE key > 10 ORDER BY key DESC LIMIT 10)
+          |ORDER BY key DESC
+          |""".stripMargin
+      checkSorts(query, 0, 1)
+    }
+  }
+
+  test("remove redundant sorts with broadcast hash join") {
+    withTempView("t1", "t2") {
+      spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
+      spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
+      val queryTemplate = """
+        |SELECT t1.key FROM
+        | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
+        |%s
+        | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2
+        |ON t1.key = t2.key
+        |ORDER BY %s
+      """.stripMargin
+
+      val innerJoinAsc = queryTemplate.format("JOIN", "t2.key ASC")
+      checkSorts(innerJoinAsc, 1, 1)
+
+      val innerJoinDesc = queryTemplate.format("JOIN", "t2.key DESC")
+      checkSorts(innerJoinDesc, 0, 1)
+
+      val innerJoinDesc1 = queryTemplate.format("JOIN", "t1.key DESC")
+      checkSorts(innerJoinDesc1, 1, 1)
+
+      val leftOuterJoinDesc = queryTemplate.format("LEFT JOIN", "t1.key DESC")
+      checkSorts(leftOuterJoinDesc, 0, 1)
+    }
+  }
+
+  test("remove redundant sorts with sort merge join") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("t1", "t2") {
+        spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
+        spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
+        val query = """
+          |SELECT t1.key FROM
+          | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
+          |JOIN
+          | (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2
+          |ON t1.key = t2.key
+          |ORDER BY t1.key
+        """.stripMargin
+
+        val queryAsc = query + " ASC"
+        checkSorts(queryAsc, 2, 3)
+
+        // The top level sort should only be eliminated if its order is ascending with SMJ.
+        val queryDesc = query + " DESC"
+        checkSorts(queryDesc, 3, 3)
+      }
+    }
+  }
+
+  test("cached sorted data doesn't need to be re-sorted") {
+    withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
+      val df = spark.range(1000).select('id as "key").sort('key.desc).cache()
+      val resorted = df.sort('key.desc)
+      val sortedAsc = df.sort('key.asc)
+      checkNumSorts(df, 0)
+      checkNumSorts(resorted, 0)
+      checkNumSorts(sortedAsc, 1)
+      val result = resorted.collect()
+      withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
+        val resorted = df.sort('key.desc)
+        checkNumSorts(resorted, 1)
+        checkAnswer(resorted, result)
+      }
+    }
+  }
+}
+
+class RemoveRedundantSortsSuite extends RemoveRedundantSortsSuiteBase
+  with DisableAdaptiveExecutionSuite
+
+class RemoveRedundantSortsSuiteAE extends RemoveRedundantSortsSuiteBase
+  with EnableAdaptiveExecutionSuite
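A note on the sort merge join counts: the two sorts that remain with the rule enabled appear to be the ones EnsureRequirements inserts below the join, which sorts both inputs ascending on the join key and reports that ordering on its output. The final ORDER BY t1.key ASC is therefore already satisfied, while DESC is not, hence 2 versus 3. The distribution half of the check passes here because each LIMIT subquery is planned as a single-partition operator. A REPL-style sketch for counting the remaining sorts, assuming the t1/t2 views and query strings from this suite and a SparkSession named spark (with AQE enabled you would need collectWithSubqueries, as the suite does):

import org.apache.spark.sql.execution.SortExec

// Count SortExec nodes in the executed plan of a query.
def numSorts(query: String): Int =
  spark.sql(query).queryExecution.executedPlan.collect { case s: SortExec => s }.size

// With spark.sql.autoBroadcastJoinThreshold = -1 and the suite's query:
//   numSorts(queryAsc)  == 2   // final ASC sort removed
//   numSorts(queryDesc) == 3   // final DESC sort kept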
