Skip to content

Commit ff7ab34

Browse files
ulysses-youcloud-fan
authored andcommitted
[SPARK-39915][SQL] Dataset.repartition(N) may not create N partitions Non-AQE part
### What changes were proposed in this pull request? Skip optimize the root user-specified repartition in `PropagateEmptyRelation`. ### Why are the changes needed? Spark should preserve the final repatition which can affect the final output partition which is user-specified. For example: ```scala spark.sql("select * from values(1) where 1 < rand()").repartition(1) // before: == Optimized Logical Plan == LocalTableScan <empty>, [col1#0] // after: == Optimized Logical Plan == Repartition 1, true +- LocalRelation <empty>, [col1#0] ``` ### Does this PR introduce _any_ user-facing change? yes, the empty plan may change ### How was this patch tested? add test Closes #37706 from ulysses-you/empty. Authored-by: ulysses-you <ulyssesyou18@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent e0cb2eb commit ff7ab34

File tree

5 files changed

+88
-4
lines changed

5 files changed

+88
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,9 @@ package object dsl {
501501
def repartition(num: Integer): LogicalPlan =
502502
Repartition(num, shuffle = true, logicalPlan)
503503

504+
def repartition(): LogicalPlan =
505+
RepartitionByExpression(Seq.empty, logicalPlan, None)
506+
504507
def distribute(exprs: Expression*)(n: Int): LogicalPlan =
505508
RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
506509

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
2323
import org.apache.spark.sql.catalyst.plans._
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules._
26+
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
2627
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
2728

2829
/**
@@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
4445
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
4546
*/
4647
abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
48+
// This tag is used to mark a repartition as a root repartition which is user-specified
49+
private[sql] val ROOT_REPARTITION = TreeNodeTag[Unit]("ROOT_REPARTITION")
50+
4751
protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
4852
case p: LocalRelation => p.data.isEmpty
4953
case _ => false
@@ -137,8 +141,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
137141
case _: GlobalLimit if !p.isStreaming => empty(p)
138142
case _: LocalLimit if !p.isStreaming => empty(p)
139143
case _: Offset => empty(p)
140-
case _: Repartition => empty(p)
141-
case _: RepartitionByExpression => empty(p)
144+
case _: RepartitionOperation =>
145+
if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
146+
empty(p)
147+
} else {
148+
p.unsetTagValue(ROOT_REPARTITION)
149+
p
150+
}
142151
case _: RebalancePartitions => empty(p)
143152
// An aggregate with non-empty group expression will return one output row per group when the
144153
// input to the aggregate is not empty. If the input to the aggregate is empty then all groups
@@ -162,13 +171,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
162171
case _ => p
163172
}
164173
}
174+
175+
protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
176+
case _: Repartition => true
177+
case r: RepartitionByExpression
178+
if r.optNumPartitions.isDefined || r.partitionExpressions.nonEmpty => true
179+
case _ => false
180+
}
181+
182+
protected def applyInternal(plan: LogicalPlan): LogicalPlan
183+
184+
/**
185+
* Add a [[ROOT_REPARTITION]] tag for the root user-specified repartition so this rule can
186+
* skip optimize it.
187+
*/
188+
private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
189+
case p: Project => p.mapChildren(addTagForRootRepartition)
190+
case f: Filter => f.mapChildren(addTagForRootRepartition)
191+
case r if userSpecifiedRepartition(r) =>
192+
r.setTagValue(ROOT_REPARTITION, ())
193+
r
194+
case _ => plan
195+
}
196+
197+
override def apply(plan: LogicalPlan): LogicalPlan = {
198+
val planWithTag = addTagForRootRepartition(plan)
199+
applyInternal(planWithTag)
200+
}
165201
}
166202

167203
/**
168204
* This rule runs in the normal optimizer
169205
*/
170206
object PropagateEmptyRelation extends PropagateEmptyRelationBase {
171-
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
207+
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
172208
_.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
173209
commonApplyFunc
174210
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,4 +327,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
327327
.fromExternalRows(Seq($"a".int, $"b".int, $"window".long.withNullability(false)), Nil)
328328
comparePlans(Optimize.execute(originalQuery.analyze), expected.analyze)
329329
}
330+
331+
test("Propagate empty relation with repartition") {
332+
val emptyRelation = LocalRelation($"a".int, $"b".int)
333+
comparePlans(Optimize.execute(
334+
emptyRelation.repartition(1).sortBy($"a".asc).analyze
335+
), emptyRelation.analyze)
336+
337+
comparePlans(Optimize.execute(
338+
emptyRelation.distribute($"a")(1).sortBy($"a".asc).analyze
339+
), emptyRelation.analyze)
340+
341+
comparePlans(Optimize.execute(
342+
emptyRelation.repartition().analyze
343+
), emptyRelation.analyze)
344+
345+
comparePlans(Optimize.execute(
346+
emptyRelation.repartition(1).sortBy($"a".asc).repartition().analyze
347+
), emptyRelation.analyze)
348+
}
349+
350+
test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
351+
val emptyRelation = LocalRelation($"a".int, $"b".int)
352+
val p1 = emptyRelation.repartition(1).analyze
353+
comparePlans(Optimize.execute(p1), p1)
354+
355+
val p2 = emptyRelation.repartition(1).select($"a").analyze
356+
comparePlans(Optimize.execute(p2), p2)
357+
358+
val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
359+
comparePlans(Optimize.execute(p3), p3)
360+
361+
val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
362+
comparePlans(Optimize.execute(p4), p4)
363+
364+
val p5 = emptyRelation.sortBy("$a".asc).repartition().limit(1).repartition(1).analyze
365+
val expected5 = emptyRelation.repartition(1).analyze
366+
comparePlans(Optimize.execute(p5), expected5)
367+
}
330368
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
6969
empty(j)
7070
}
7171

72-
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
72+
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
7373
// LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
7474
// `PropagateEmptyRelationBase.commonApplyFunc`
7575
// LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3419,6 +3419,13 @@ class DataFrameSuite extends QueryTest
34193419
Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
34203420
Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
34213421
}
3422+
3423+
test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
3424+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
3425+
val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
3426+
assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
3427+
}
3428+
}
34223429
}
34233430

34243431
case class GroupByKey(a: Int, b: Int)

0 commit comments

Comments
 (0)