Skip to content

Commit 18fc8e8

Browse files
ulysses-youdongjoon-hyun
authored and committed
[SPARK-39915][SQL][3.3] Dataset.repartition(N) may not create N partitions Non-AQE part
### What changes were proposed in this pull request?
Backport #37706 to branch-3.3.

Skip optimizing the root user-specified repartition in `PropagateEmptyRelation`.

### Why are the changes needed?
Spark should preserve the final repartition, because it is user-specified and can affect the final output partitioning. For example:
```scala
spark.sql("select * from values(1) where 1 < rand()").repartition(1)

// before:
== Optimized Logical Plan ==
LocalTableScan <empty>, [col1#0]

// after:
== Optimized Logical Plan ==
Repartition 1, true
+- LocalRelation <empty>, [col1#0]
```

### Does this PR introduce _any_ user-facing change?
Yes, the empty plan may change.

### How was this patch tested?
Added tests.

Closes #37730 from ulysses-you/empty-3.3.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 4f69c98 commit 18fc8e8

File tree

5 files changed

+88
-4
lines changed

5 files changed

+88
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,9 @@ package object dsl {
491491
/** Wraps this plan in a [[Repartition]] with `num` partitions and shuffling enabled. */
def repartition(num: Integer): LogicalPlan =
  Repartition(numPartitions = num, shuffle = true, child = logicalPlan)
493493

494+
/**
 * Wraps this plan in a [[RepartitionByExpression]] with no partition expressions and no
 * explicit partition number.
 */
def repartition(): LogicalPlan =
  RepartitionByExpression(
    partitionExpressions = Seq.empty,
    child = logicalPlan,
    optNumPartitions = None)
496+
494497
/** Wraps this plan in a [[RepartitionByExpression]] over `exprs` with `n` partitions. */
def distribute(exprs: Expression*)(n: Int): LogicalPlan = {
  val distributed: LogicalPlan = RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
  distributed
}
496499

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
2323
import org.apache.spark.sql.catalyst.plans._
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules._
26+
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
2627
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
2728

2829
/**
@@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
4445
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
4546
*/
4647
abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
48+
// Tag placed on the root, user-specified repartition of a plan so that this rule can
// recognize it and leave it in place.
private[sql] val ROOT_REPARTITION: TreeNodeTag[Unit] = TreeNodeTag[Unit]("ROOT_REPARTITION")
50+
4751
protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
4852
case p: LocalRelation => p.data.isEmpty
4953
case _ => false
@@ -136,8 +140,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
136140
case _: Sort => empty(p)
137141
case _: GlobalLimit if !p.isStreaming => empty(p)
138142
case _: LocalLimit if !p.isStreaming => empty(p)
139-
case _: Repartition => empty(p)
140-
case _: RepartitionByExpression => empty(p)
143+
case _: RepartitionOperation =>
144+
if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
145+
empty(p)
146+
} else {
147+
p.unsetTagValue(ROOT_REPARTITION)
148+
p
149+
}
141150
case _: RebalancePartitions => empty(p)
142151
// An aggregate with non-empty group expression will return one output row per group when the
143152
// input to the aggregate is not empty. If the input to the aggregate is empty then all groups
@@ -160,13 +169,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
160169
case _ => p
161170
}
162171
}
172+
173+
/**
 * Returns true when `p` is a repartition the user asked for explicitly: any [[Repartition]],
 * or a [[RepartitionByExpression]] that carries a partition number or partition expressions.
 */
protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
  case _: Repartition => true
  case byExpr: RepartitionByExpression =>
    byExpr.optNumPartitions.isDefined || byExpr.partitionExpressions.nonEmpty
  case _ => false
}
179+
180+
/**
 * Runs the empty-relation propagation over `plan`. [[apply]] calls this after it has tagged
 * the root user-specified repartition; concrete rules supply the actual traversal.
 */
protected def applyInternal(plan: LogicalPlan): LogicalPlan
181+
182+
/**
 * Adds a [[ROOT_REPARTITION]] tag to the root user-specified repartition so that this rule
 * skips optimizing it. Starting from the root, only [[Project]] and [[Filter]] nodes are
 * looked through; the search stops at any other node.
 */
private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
  case _: Project | _: Filter =>
    plan.mapChildren(addTagForRootRepartition)
  case repartition if userSpecifiedRepartition(repartition) =>
    // The tag is set in place on the node; the same node instance is returned.
    repartition.setTagValue(ROOT_REPARTITION, ())
    repartition
  case other => other
}
194+
195+
/**
 * Tags the root user-specified repartition (if any) and then delegates to
 * [[applyInternal]] on the tagged plan.
 */
override def apply(plan: LogicalPlan): LogicalPlan =
  applyInternal(addTagForRootRepartition(plan))
163199
}
164200

165201
/**
166202
* This rule runs in the normal optimizer
167203
*/
168204
object PropagateEmptyRelation extends PropagateEmptyRelationBase {
169-
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
205+
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
170206
_.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
171207
commonApplyFunc
172208
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,4 +309,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
309309
val optimized2 = Optimize.execute(plan2)
310310
comparePlans(optimized2, expected)
311311
}
312+
313+
test("Propagate empty relation with repartition") {
  val emptyRelation = LocalRelation($"a".int, $"b".int)
  val expected = emptyRelation.analyze

  // In every plan below the repartition is either not reachable from the root through
  // Project/Filter only, or it has neither a partition number nor partition expressions,
  // so the whole plan is expected to collapse to the empty relation.
  val plans = Seq[LogicalPlan](
    emptyRelation.repartition(1).sortBy($"a".asc),
    emptyRelation.distribute($"a")(1).sortBy($"a".asc),
    emptyRelation.repartition(),
    emptyRelation.repartition(1).sortBy($"a".asc).repartition()
  )

  plans.foreach { plan =>
    comparePlans(Optimize.execute(plan.analyze), expected)
  }
}
331+
332+
test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
  val emptyRelation = LocalRelation($"a".int, $"b".int)

  // A root repartition(1), possibly below Project and/or Filter, must survive optimization.
  val p1 = emptyRelation.repartition(1).analyze
  comparePlans(Optimize.execute(p1), p1)

  val p2 = emptyRelation.repartition(1).select($"a").analyze
  comparePlans(Optimize.execute(p2), p2)

  val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
  comparePlans(Optimize.execute(p3), p3)

  val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
  comparePlans(Optimize.execute(p4), p4)

  // Everything below the root repartition(1) is still propagated as empty.
  // Sort by column `a` via the `$"a"` interpolator (previously the misplaced `"$a"`),
  // consistent with the other plans in this suite.
  val p5 = emptyRelation.sortBy($"a".asc).repartition().limit(1).repartition(1).analyze
  val expected5 = emptyRelation.repartition(1).analyze
  comparePlans(Optimize.execute(p5), expected5)
}
312350
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
6969
empty(j)
7070
}
7171

72-
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
72+
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
7373
// LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
7474
// `PropagateEmptyRelationBase.commonApplyFunc`
7575
// LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,6 +3281,13 @@ class DataFrameSuite extends QueryTest
32813281
Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
32823282
Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
32833283
}
3284+
3285+
test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
  // Run with adaptive execution disabled, so the non-AQE planning path is exercised.
  withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
    val filtered = spark.sql("select * from values(1) where 1 < rand()")
    val df = filtered.repartition(2)
    val executedRdd = df.queryExecution.executedPlan.execute()
    assert(executedRdd.getNumPartitions == 2)
  }
}
32843291
}
32853292

32863293
case class GroupByKey(a: Int, b: Int)

0 commit comments

Comments
 (0)