Skip to content

Commit 59fada7

Browse files
committed
Addressed feedback
1 parent a9c97ce commit 59fada7

File tree

3 files changed

+41
-16
lines changed

3 files changed

+41
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
415415
val (h, t) = branches.span(_._1 != TrueLiteral)
416416
CaseWhen( h :+ t.head, None)
417417

418-
case CaseWhen((cond, branchValue) :: Nil, elseValue) =>
419-
If(cond, branchValue, elseValue.getOrElse(Literal(null, branchValue.dataType)))
418+
case CaseWhen(branches, elseValue) if branches.length == 1 =>
419+
// Using pattern matching like `CaseWhen((cond, branchValue) :: Nil, elseValue)` will not
420+
// work since the implementation of `branches` can be `ArrayBuffer`. A full test is in
421+
// "SPARK-24892: simplify `CaseWhen` to `If` when there is only one branch",
422+
// `SQLQuerySuite.scala`.
423+
val cond = branches.head._1
424+
val trueValue = branches.head._2
425+
val falseValue = elseValue.getOrElse(Literal(null, trueValue.dataType))
426+
If(cond, trueValue, falseValue)
420427
}
421428
}
422429
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
3939
}
4040

4141
private val trueBranch = (TrueLiteral, Literal(5))
42-
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
42+
private val normalBranch1 = (NonFoldableLiteral(true), Literal(10))
43+
private val normalBranch2 = (NonFoldableLiteral(false), Literal(3))
4344
private val unreachableBranch = (FalseLiteral, Literal(20))
4445
private val nullBranch = (Literal.create(null, NullType), Literal(30))
4546

@@ -60,18 +61,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
6061
test("remove unreachable branches") {
6162
// i.e. removing branches whose conditions are always false
6263
assertEquivalent(
63-
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
64-
If(normalBranch._1, normalBranch._2, Literal(null, normalBranch._2.dataType)))
64+
CaseWhen(unreachableBranch :: normalBranch1 :: unreachableBranch ::
65+
normalBranch2 :: nullBranch :: Nil, None),
66+
CaseWhen(normalBranch1 :: normalBranch2 :: Nil, None))
6567
}
6668

6769
test("simplify CaseWhen to If when there is only one branch") {
6870
assertEquivalent(
69-
CaseWhen(normalBranch :: Nil, None),
70-
If(normalBranch._1, normalBranch._2, Literal(null, normalBranch._2.dataType)))
71+
CaseWhen(normalBranch1 :: Nil, Some(Literal(30))),
72+
If(normalBranch1._1, normalBranch1._2, Literal(30)))
7173

7274
assertEquivalent(
73-
CaseWhen(normalBranch :: Nil, Some(Literal(30))),
74-
If(normalBranch._1, normalBranch._2, Literal(30)))
75+
CaseWhen(normalBranch1 :: Nil, None),
76+
If(normalBranch1._1, normalBranch1._2, Literal(null, normalBranch1._2.dataType)))
77+
78+
assertEquivalent(
79+
CaseWhen(unreachableBranch :: normalBranch1 :: unreachableBranch :: nullBranch :: Nil, None),
80+
If(normalBranch1._1, normalBranch1._2, Literal(null, normalBranch1._2.dataType)))
7581
}
7682

7783
test("remove entire CaseWhen if only the else branch is reachable") {
@@ -86,28 +92,28 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
8692

8793
test("remove entire CaseWhen if the first branch is always true") {
8894
assertEquivalent(
89-
CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
95+
CaseWhen(trueBranch :: normalBranch1 :: nullBranch :: Nil, None),
9096
Literal(5))
9197

9298
// Test branch elimination and simplification in combination
9399
assertEquivalent(
94-
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
100+
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch1
95101
:: Nil, None),
96102
Literal(5))
97103

98104
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
99105
assertEquivalent(
100-
CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
101-
CaseWhen(normalBranch :: trueBranch :: Nil, None))
106+
CaseWhen(normalBranch1 :: trueBranch :: normalBranch1 :: Nil, None),
107+
CaseWhen(normalBranch1 :: trueBranch :: Nil, None))
102108
}
103109

104110
test("simplify CaseWhen, prune branches following a definite true") {
105111
assertEquivalent(
106-
CaseWhen(normalBranch :: unreachableBranch ::
112+
CaseWhen(normalBranch1 :: unreachableBranch ::
107113
unreachableBranch :: nullBranch ::
108-
trueBranch :: normalBranch ::
114+
trueBranch :: normalBranch1 ::
109115
Nil,
110116
None),
111-
CaseWhen(normalBranch :: trueBranch :: Nil, None))
117+
CaseWhen(normalBranch1 :: trueBranch :: Nil, None))
112118
}
113119
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2813,4 +2813,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
28132813
checkAnswer(df, Seq(Row(3, 99, 1)))
28142814
}
28152815
}
2816+
2817+
test("SPARK-24892: simplify `CaseWhen` to `If` when there is only one branch") {
2818+
withTable("t") {
2819+
Seq(Some(1), null, Some(3)).toDF("a").write.saveAsTable("t")
2820+
2821+
val plan1 = sql("select case when a is null then 1 end col1 from t")
2822+
val plan2 = sql("select if(a is null, 1, null) col1 from t")
2823+
2824+
checkAnswer(plan1, Row(null) :: Row(1) :: Row(null) :: Nil)
2825+
comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
2826+
}
2827+
}
28162828
}

0 commit comments

Comments
 (0)