Skip to content

Commit bf6e45a

Browse files
fix
1 parent ab0bad9 commit bf6e45a

File tree

3 files changed

+64
-22
lines changed

3 files changed

+64
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,9 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
10411041
* Note that changes in the final output ordering may affect the file size (SPARK-32318).
10421042
* This rule handles the following cases:
10431043
* 1) if the sort order is empty or the sort order does not have any reference
1044-
* 2) if the child is already sorted
1044+
* 2) if the child's output ordering already satisfies the sort orders, and either the Sort
1045+
* operator is a local sort, or it is a global sort whose child is another global Sort
1046+
* operator or a Range operator.
10451047
* 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
10461048
* RepartitionByExpression (with deterministic expressions) operators
10471049
* 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
@@ -1056,8 +1058,14 @@ object EliminateSorts extends Rule[LogicalPlan] {
10561058
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
10571059
val newOrders = orders.filterNot(_.child.foldable)
10581060
if (newOrders.isEmpty) child else s.copy(order = newOrders)
1059-
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
1060-
child
1061+
case s @ Sort(orders, global, child)
1062+
if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
1063+
(global, child) match {
1064+
case (false, _) => child
1065+
case (true, r: Range) => r
1066+
case (true, s @ Sort(_, true, _)) => s
1067+
case (true, _) => s.copy(child = recursiveRemoveSort(child))
1068+
}
10611069
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
10621070
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
10631071
j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,22 @@ class EliminateSortsSuite extends PlanTest {
9797
comparePlans(optimized, correctAnswer)
9898
}
9999

100-
test("remove redundant order by") {
100+
test("remove redundant local sort") {
101101
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
102-
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
102+
val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst)
103103
val optimized = Optimize.execute(unnecessaryReordered.analyze)
104104
val correctAnswer = orderedPlan.limit(2).select('a).analyze
105105
comparePlans(Optimize.execute(optimized), correctAnswer)
106106
}
107107

108+
test("should not remove global sort") {
109+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
110+
val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
111+
val optimized = Optimize.execute(reordered.analyze)
112+
val correctAnswer = reordered.analyze
113+
comparePlans(Optimize.execute(optimized), correctAnswer)
114+
}
115+
108116
test("do not remove sort if the order is different") {
109117
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
110118
val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc)
@@ -113,22 +121,39 @@ class EliminateSortsSuite extends PlanTest {
113121
comparePlans(optimized, correctAnswer)
114122
}
115123

116-
test("filters don't affect order") {
124+
test("filters don't affect order for local sort") {
117125
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
118-
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
126+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
119127
val optimized = Optimize.execute(filteredAndReordered.analyze)
120128
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
121129
comparePlans(optimized, correctAnswer)
122130
}
123131

124-
test("limits don't affect order") {
132+
test("should keep global sort when child is a filter operator with the same ordering") {
133+
val projectPlan = testRelation.select('a, 'b)
134+
val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc)
135+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
136+
val optimized = Optimize.execute(filteredAndReordered.analyze)
137+
val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze
138+
comparePlans(optimized, correctAnswer)
139+
}
140+
141+
test("limits don't affect order for local sort") {
125142
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
126-
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
143+
val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc)
127144
val optimized = Optimize.execute(filteredAndReordered.analyze)
128145
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
129146
comparePlans(optimized, correctAnswer)
130147
}
131148

149+
test("should keep global sort when child is a limit operator with the same ordering") {
150+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
151+
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
152+
val optimized = Optimize.execute(filteredAndReordered.analyze)
153+
val correctAnswer = filteredAndReordered.analyze
154+
comparePlans(optimized, correctAnswer)
155+
}
156+
132157
test("different sorts are not simplified if limit is in between") {
133158
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
134159
.orderBy('a.asc)
@@ -331,4 +356,26 @@ class EliminateSortsSuite extends PlanTest {
331356
val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze)
332357
comparePlans(optimized, correctAnswer)
333358
}
359+
360+
test("remove two consecutive global sorts with same ordering") {
361+
Seq(
362+
(testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)),
363+
(testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc),
364+
testRelation.orderBy('a.asc, 'b.desc))
365+
).foreach { case (ordered, answer) =>
366+
val optimized = Optimize.execute(ordered.analyze)
367+
comparePlans(optimized, answer.analyze)
368+
}
369+
}
370+
371+
test("should keep global sort when child is a local sort with the same ordering") {
372+
val correctAnswer = testRelation.orderBy('a.asc).analyze
373+
Seq(
374+
testRelation.sortBy('a.asc).orderBy('a.asc),
375+
testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc)
376+
).foreach { ordered =>
377+
val optimized = Optimize.execute(ordered.analyze)
378+
comparePlans(optimized, correctAnswer)
379+
}
380+
}
334381
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,19 +234,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
234234
}
235235
}
236236

237-
test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
238-
val query = testData.select('key, 'value).sort('key.desc).cache()
239-
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
240-
val resorted = query.sort('key.desc)
241-
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
242-
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
243-
(1 to 100).reverse)
244-
// with a different order, the sort is needed
245-
val sortedAsc = query.sort('key)
246-
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1)
247-
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
248-
}
249-
250237
test("PartitioningCollection") {
251238
withTempView("normal", "small", "tiny") {
252239
testData.createOrReplaceTempView("normal")

0 commit comments

Comments
 (0)