Skip to content

Commit d912920

Browse files
committed
[SPARK-39180][SQL] Simplify the planning of limit and offset
### What changes were proposed in this pull request? This PR simplifies the planning of limit and offset: 1. Unify the semantics of physical plans that need to deal with limit + offset. These physical plans always do limit first, then offset. The planner rule should set limit and offset properly, for different plans, such as limit + offset and offset + limit. 2. Refactor the planner rule `SpecialLimit` to reuse the code of planning `TakeOrderedAndProjectExec`. 3. Let `GlobalLimitExec` to handle offset as well, so that we can remove `GlobalLimitAndOffsetExec`. This matches `CollectLimitExec`. ### Why are the changes needed? code simplification ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #36541 from cloud-fan/offset. Lead-authored-by: Wenchen Fan <cloud0fan@gmail.com> Co-authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 2e73d82 commit d912920

File tree

3 files changed

+124
-128
lines changed

3 files changed

+124
-128
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,12 +1303,25 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
13031303
copy(child = newChild)
13041304
}
13051305

1306+
object OffsetAndLimit {
1307+
def unapply(p: GlobalLimit): Option[(Int, Int, LogicalPlan)] = {
1308+
p match {
1309+
// Optimizer pushes local limit through offset, so we need to match the plan this way.
1310+
case GlobalLimit(IntegerLiteral(globalLimit),
1311+
Offset(IntegerLiteral(offset),
1312+
LocalLimit(IntegerLiteral(localLimit), child)))
1313+
if globalLimit + offset == localLimit =>
1314+
Some((offset, globalLimit, child))
1315+
case _ => None
1316+
}
1317+
}
1318+
}
1319+
13061320
object LimitAndOffset {
1307-
def unapply(p: GlobalLimit): Option[(Expression, Expression, LogicalPlan)] = {
1321+
def unapply(p: Offset): Option[(Int, Int, LogicalPlan)] = {
13081322
p match {
1309-
case GlobalLimit(le1, Offset(le2, LocalLimit(le3, child))) if le1.eval().asInstanceOf[Int]
1310-
+ le2.eval().asInstanceOf[Int] == le3.eval().asInstanceOf[Int] =>
1311-
Some((le1, le2, child))
1323+
case Offset(IntegerLiteral(offset), Limit(IntegerLiteral(limit), child)) =>
1324+
Some((limit, offset, child))
13121325
case _ => None
13131326
}
13141327
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -81,55 +81,56 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
8181
*/
8282
object SpecialLimits extends Strategy {
8383
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
84-
case ReturnAnswer(rootPlan) => rootPlan match {
85-
case Limit(IntegerLiteral(limit), Sort(order, true, child))
86-
if limit < conf.topKSortFallbackThreshold =>
87-
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
88-
case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
89-
if limit < conf.topKSortFallbackThreshold =>
90-
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
84+
// Call `planTakeOrdered` first which matches a larger plan.
85+
case ReturnAnswer(rootPlan) => planTakeOrdered(rootPlan).getOrElse(rootPlan match {
86+
// We should match the combination of limit and offset first, to get the optimal physical
87+
// plan, instead of planning limit and offset separately.
88+
case LimitAndOffset(limit, offset, child) =>
89+
CollectLimitExec(limit = limit, child = planLater(child), offset = offset)
90+
case OffsetAndLimit(offset, limit, child) =>
91+
// 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
92+
CollectLimitExec(limit = offset + limit, child = planLater(child), offset = offset)
9193
case Limit(IntegerLiteral(limit), child) =>
92-
CollectLimitExec(limit, planLater(child)) :: Nil
93-
case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
94-
Sort(order, true, child)) if limit + offset < conf.topKSortFallbackThreshold =>
95-
TakeOrderedAndProjectExec(
96-
limit, order, child.output, planLater(child), offset) :: Nil
97-
case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
98-
Project(projectList, Sort(order, true, child)))
99-
if limit + offset < conf.topKSortFallbackThreshold =>
100-
TakeOrderedAndProjectExec(
101-
limit, order, projectList, planLater(child), offset) :: Nil
102-
case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset), child) =>
103-
CollectLimitExec(limit, planLater(child), offset) :: Nil
94+
CollectLimitExec(limit = limit, child = planLater(child))
10495
case logical.Offset(IntegerLiteral(offset), child) =>
105-
CollectLimitExec(child = planLater(child), offset = offset) :: Nil
96+
CollectLimitExec(child = planLater(child), offset = offset)
10697
case Tail(IntegerLiteral(limit), child) =>
107-
CollectTailExec(limit, planLater(child)) :: Nil
108-
case other => planLater(other) :: Nil
109-
}
98+
CollectTailExec(limit, planLater(child))
99+
case other => planLater(other)
100+
}) :: Nil
101+
102+
case other => planTakeOrdered(other).toSeq
103+
}
104+
105+
private def planTakeOrdered(plan: LogicalPlan): Option[SparkPlan] = plan match {
106+
// We should match the combination of limit and offset first, to get the optimal physical
107+
// plan, instead of planning limit and offset separately.
108+
case LimitAndOffset(limit, offset, Sort(order, true, child))
109+
if limit < conf.topKSortFallbackThreshold =>
110+
Some(TakeOrderedAndProjectExec(
111+
limit, order, child.output, planLater(child), offset))
112+
case LimitAndOffset(limit, offset, Project(projectList, Sort(order, true, child)))
113+
if limit < conf.topKSortFallbackThreshold =>
114+
Some(TakeOrderedAndProjectExec(
115+
limit, order, projectList, planLater(child), offset))
116+
// 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
117+
case OffsetAndLimit(offset, limit, Sort(order, true, child))
118+
if offset + limit < conf.topKSortFallbackThreshold =>
119+
Some(TakeOrderedAndProjectExec(
120+
offset + limit, order, child.output, planLater(child), offset))
121+
case OffsetAndLimit(offset, limit, Project(projectList, Sort(order, true, child)))
122+
if offset + limit < conf.topKSortFallbackThreshold =>
123+
Some(TakeOrderedAndProjectExec(
124+
offset + limit, order, projectList, planLater(child), offset))
110125
case Limit(IntegerLiteral(limit), Sort(order, true, child))
111126
if limit < conf.topKSortFallbackThreshold =>
112-
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
127+
Some(TakeOrderedAndProjectExec(
128+
limit, order, child.output, planLater(child)))
113129
case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
114130
if limit < conf.topKSortFallbackThreshold =>
115-
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
116-
// This is a global LIMIT and OFFSET over a logical sorting operator,
117-
// where the sum of specified limit and specified offset is less than a heuristic threshold.
118-
// In this case we generate a physical top-K sorting operator, passing down
119-
// the limit and offset values to be evaluated inline during the physical
120-
// sorting operation for greater efficiency.
121-
case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
122-
Sort(order, true, child)) if limit + offset < conf.topKSortFallbackThreshold =>
123-
TakeOrderedAndProjectExec(
124-
limit, order, child.output, planLater(child), offset) :: Nil
125-
case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
126-
Project(projectList, Sort(order, true, child)))
127-
if limit + offset < conf.topKSortFallbackThreshold =>
128-
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child), offset) :: Nil
129-
case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset), child) =>
130-
GlobalLimitAndOffsetExec(limit, offset, planLater(child)) :: Nil
131-
case _ =>
132-
Nil
131+
Some(TakeOrderedAndProjectExec(
132+
limit, order, projectList, planLater(child)))
133+
case _ => None
133134
}
134135
}
135136

@@ -814,12 +815,19 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
814815
case logical.LocalRelation(output, data, _) =>
815816
LocalTableScanExec(output, data) :: Nil
816817
case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil
818+
// We should match the combination of limit and offset first, to get the optimal physical
819+
// plan, instead of planning limit and offset separately.
820+
case LimitAndOffset(limit, offset, child) =>
821+
GlobalLimitExec(limit, planLater(child), offset) :: Nil
822+
case OffsetAndLimit(offset, limit, child) =>
823+
// 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
824+
GlobalLimitExec(limit = offset + limit, child = planLater(child), offset = offset) :: Nil
817825
case logical.LocalLimit(IntegerLiteral(limit), child) =>
818826
execution.LocalLimitExec(limit, planLater(child)) :: Nil
819827
case logical.GlobalLimit(IntegerLiteral(limit), child) =>
820828
execution.GlobalLimitExec(limit, planLater(child)) :: Nil
821829
case logical.Offset(IntegerLiteral(offset), child) =>
822-
GlobalLimitAndOffsetExec(offset = offset, child = planLater(child)) :: Nil
830+
GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
823831
case union: logical.Union =>
824832
execution.UnionExec(union.children.map(planLater)) :: Nil
825833
case g @ logical.Generate(generator, _, outer, _, _, child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 56 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ trait LimitExec extends UnaryExecNode {
3737
}
3838

3939
/**
40-
* Take the first `limit` + `offset` elements and collect them to a single partition and then to
41-
* drop the first `offset` elements.
40+
* Take the first `limit` elements, collect them to a single partition and then to drop the
41+
* first `offset` elements.
4242
*
43-
* This operator will be used when a logical `Limit` operation is the final operator in an
44-
* logical plan, which happens when the user is collecting results back to the driver.
43+
* This operator will be used when a logical `Limit` and/or `Offset` operation is the final operator
44+
* in an logical plan, which happens when the user is collecting results back to the driver.
4545
*/
4646
case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0) extends LimitExec {
4747
assert(limit >= 0 || (limit == -1 && offset > 0))
@@ -56,7 +56,7 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
5656
// Then [1, 2, 3] will be taken and output [3].
5757
if (limit >= 0) {
5858
if (offset > 0) {
59-
child.executeTake(limit + offset).drop(offset)
59+
child.executeTake(limit).drop(offset)
6060
} else {
6161
child.executeTake(limit)
6262
}
@@ -79,11 +79,7 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
7979
childRDD
8080
} else {
8181
val locallyLimited = if (limit >= 0) {
82-
if (offset > 0) {
83-
childRDD.mapPartitionsInternal(_.take(limit + offset))
84-
} else {
85-
childRDD.mapPartitionsInternal(_.take(limit))
86-
}
82+
childRDD.mapPartitionsInternal(_.take(limit))
8783
} else {
8884
childRDD
8985
}
@@ -98,7 +94,7 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
9894
}
9995
if (limit >= 0) {
10096
if (offset > 0) {
101-
singlePartitionRDD.mapPartitionsInternal(_.slice(offset, offset + limit))
97+
singlePartitionRDD.mapPartitionsInternal(_.slice(offset, limit))
10298
} else {
10399
singlePartitionRDD.mapPartitionsInternal(_.take(limit))
104100
}
@@ -164,8 +160,8 @@ trait BaseLimitExec extends LimitExec with CodegenSupport {
164160

165161
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
166162

167-
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
168-
iter.take(limit)
163+
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitionsInternal {
164+
iter => iter.take(limit)
169165
}
170166

171167
override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -215,61 +211,52 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
215211
}
216212

217213
/**
218-
* Take the first `limit` elements of the child's single output partition.
214+
* Take the first `limit` elements and then drop the first `offset` elements in the child's single
215+
* output partition.
219216
*/
220-
case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
221-
222-
override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
223-
224-
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
225-
copy(child = newChild)
226-
}
227-
228-
/**
229-
* Skip the first `offset` elements then take the first `limit` of the following elements in
230-
* the child's single output partition.
231-
*/
232-
case class GlobalLimitAndOffsetExec(
233-
limit: Int = -1,
234-
offset: Int,
235-
child: SparkPlan) extends BaseLimitExec {
236-
assert(offset > 0)
217+
case class GlobalLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
218+
extends BaseLimitExec {
219+
assert(limit >= 0 || (limit == -1 && offset > 0))
237220

238221
override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
239222

240-
override def doExecute(): RDD[InternalRow] = if (limit >= 0) {
241-
child.execute().mapPartitionsInternal(iter => iter.slice(offset, limit + offset))
242-
} else {
243-
child.execute().mapPartitionsInternal(iter => iter.drop(offset))
223+
override def doExecute(): RDD[InternalRow] = {
224+
if (offset > 0) {
225+
if (limit >= 0) {
226+
child.execute().mapPartitionsInternal(iter => iter.slice(offset, limit))
227+
} else {
228+
child.execute().mapPartitionsInternal(iter => iter.drop(offset))
229+
}
230+
} else {
231+
super.doExecute()
232+
}
244233
}
245234

246-
private lazy val skipTerm = BaseLimitExec.newLimitCountTerm()
247-
248235
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
249-
ctx.addMutableState(
250-
CodeGenerator.JAVA_INT, skipTerm, forceInline = true, useFreshName = false)
251-
if (limit >= 0) {
252-
// The counter name is already obtained by the upstream operators via `limitNotReachedChecks`.
253-
// Here we have to inline it to not change its name. This is fine as we won't have many limit
254-
// operators in one query.
255-
ctx.addMutableState(
256-
CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)
257-
s"""
258-
| if ($skipTerm < $offset) {
259-
| $skipTerm += 1;
260-
| } else if ($countTerm < $limit) {
261-
| $countTerm += 1;
262-
| ${consume(ctx, input)}
263-
| }
236+
if (offset > 0) {
237+
val skipTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "rowsSkipped", forceInline = true)
238+
if (limit > 0) {
239+
// In codegen, we skip the first `offset` rows, then take the first `limit - offset` rows.
240+
val finalLimit = limit - offset
241+
s"""
242+
| if ($skipTerm < $offset) {
243+
| $skipTerm += 1;
244+
| } else if ($countTerm < $finalLimit) {
245+
| $countTerm += 1;
246+
| ${consume(ctx, input)}
247+
| }
264248
""".stripMargin
249+
} else {
250+
s"""
251+
| if ($skipTerm < $offset) {
252+
| $skipTerm += 1;
253+
| } else {
254+
| ${consume(ctx, input)}
255+
| }
256+
""".stripMargin
257+
}
265258
} else {
266-
s"""
267-
| if ($skipTerm < $offset) {
268-
| $skipTerm += 1;
269-
| } else {
270-
| ${consume(ctx, input)}
271-
| }
272-
""".stripMargin
259+
super.doConsume(ctx, input, row)
273260
}
274261
}
275262

@@ -278,9 +265,9 @@ case class GlobalLimitAndOffsetExec(
278265
}
279266

280267
/**
281-
* Take the first limit elements as defined by the sortOrder, and do projection if needed.
282-
* This is logically equivalent to having a Limit operator after a [[SortExec]] operator,
283-
* or having a [[ProjectExec]] operator between them.
268+
* Take the first `limit` elements as defined by the sortOrder, then drop the first `offset`
269+
* elements, and do projection if needed. This is logically equivalent to having a Limit and/or
270+
* Offset operator after a [[SortExec]] operator, or having a [[ProjectExec]] operator between them.
284271
* This could have been named TopK, but Spark's top operator does the opposite in ordering
285272
* so we name it TakeOrdered to avoid confusion.
286273
*/
@@ -297,12 +284,8 @@ case class TakeOrderedAndProjectExec(
297284

298285
override def executeCollect(): Array[InternalRow] = {
299286
val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
300-
val data = if (offset > 0) {
301-
child.execute().mapPartitionsInternal(_.map(_.copy()))
302-
.takeOrdered(limit + offset)(ord).drop(offset)
303-
} else {
304-
child.execute().mapPartitionsInternal(_.map(_.copy())).takeOrdered(limit)(ord)
305-
}
287+
val limited = child.execute().mapPartitionsInternal(_.map(_.copy())).takeOrdered(limit)(ord)
288+
val data = if (offset > 0) limited.drop(offset) else limited
306289
if (projectList != child.output) {
307290
val proj = UnsafeProjection.create(projectList, child.output)
308291
data.map(r => proj(r).copy())
@@ -328,15 +311,10 @@ case class TakeOrderedAndProjectExec(
328311
val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
329312
childRDD
330313
} else {
331-
val localTopK = if (offset > 0) {
332-
childRDD.mapPartitionsInternal { iter =>
333-
Utils.takeOrdered(iter.map(_.copy()), limit + offset)(ord)
334-
}
335-
} else {
336-
childRDD.mapPartitionsInternal { iter =>
337-
Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
338-
}
314+
val localTopK = childRDD.mapPartitionsInternal { iter =>
315+
Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
339316
}
317+
340318
new ShuffledRowRDD(
341319
ShuffleExchangeExec.prepareShuffleDependency(
342320
localTopK,
@@ -347,11 +325,8 @@ case class TakeOrderedAndProjectExec(
347325
readMetrics)
348326
}
349327
singlePartitionRDD.mapPartitionsInternal { iter =>
350-
val topK = if (offset > 0) {
351-
Utils.takeOrdered(iter.map(_.copy()), limit + offset)(ord).drop(offset)
352-
} else {
353-
Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
354-
}
328+
val limited = Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
329+
val topK = if (offset > 0) limited.drop(offset) else limited
355330
if (projectList != child.output) {
356331
val proj = UnsafeProjection.create(projectList, child.output)
357332
topK.map(r => proj(r))

0 commit comments

Comments
 (0)