Skip to content

Commit f705037

Browse files
dongjoon-hyunrxin
authored andcommitted
[SPARK-14338][SQL] Improve SimplifyConditionals rule to handle null in IF/CASEWHEN
## What changes were proposed in this pull request? Currently, `SimplifyConditionals` handles `true` and `false` to optimize branches. This PR improves `SimplifyConditionals` to take advantage of `null` conditions for `if` and `CaseWhen` expressions, too. **Before** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))alteryx#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14] : +- INPUT +- Scan OneRowRelation[] ``` **After** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))alteryx#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4] : +- INPUT +- Scan OneRowRelation[] ``` **Hive** ``` hive> select if(null,1,2); OK 2 hive> select case when cast(null as boolean) then 1 else 2 end; OK 2 ``` ## How was this patch tested? Pass the Jenkins tests (including new extended test cases). Author: Dongjoon Hyun <dongjoon@apache.org> Closes apache#12122 from dongjoon-hyun/SPARK-14338.
1 parent a3e2935 commit f705037

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
527527
* Null value propagation from bottom to top of the expression tree.
528528
*/
529529
object NullPropagation extends Rule[LogicalPlan] {
530-
def nonNullLiteral(e: Expression): Boolean = e match {
530+
private def nonNullLiteral(e: Expression): Boolean = e match {
531531
case Literal(null, _) => false
532532
case _ => true
533533
}
@@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
773773
* Simplifies conditional expressions (if / case).
774774
*/
775775
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
776+
private def falseOrNullLiteral(e: Expression): Boolean = e match {
777+
case FalseLiteral => true
778+
case Literal(null, _) => true
779+
case _ => false
780+
}
781+
776782
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
777783
case q: LogicalPlan => q transformExpressionsUp {
778784
case If(TrueLiteral, trueValue, _) => trueValue
779785
case If(FalseLiteral, _, falseValue) => falseValue
786+
case If(Literal(null, _), _, falseValue) => falseValue
780787

781-
case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
788+
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
782789
// If there are branches that are always false, remove them.
783790
// If there are no more branches left, just use the else value.
784791
// Note that these two are handled together here in a single case statement because
785792
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
786-
val newBranches = branches.filter(_._1 != FalseLiteral)
793+
val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
787794
if (newBranches.isEmpty) {
788795
elseValue.getOrElse(Literal.create(null, e.dataType))
789796
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
2323
import org.apache.spark.sql.catalyst.plans.PlanTest
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules._
26-
import org.apache.spark.sql.types.IntegerType
26+
import org.apache.spark.sql.types.{IntegerType, NullType}
2727

2828

2929
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4141
private val trueBranch = (TrueLiteral, Literal(5))
4242
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
4343
private val unreachableBranch = (FalseLiteral, Literal(20))
44+
private val nullBranch = (Literal(null, NullType), Literal(30))
4445

4546
test("simplify if") {
4647
assertEquivalent(
@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
5051
assertEquivalent(
5152
If(FalseLiteral, Literal(10), Literal(20)),
5253
Literal(20))
54+
55+
assertEquivalent(
56+
If(Literal(null, NullType), Literal(10), Literal(20)),
57+
Literal(20))
5358
}
5459

5560
test("remove unreachable branches") {
5661
// i.e. removing branches whose conditions are always false
5762
assertEquivalent(
58-
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
63+
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
5964
CaseWhen(normalBranch :: Nil, None))
6065
}
6166

6267
test("remove entire CaseWhen if only the else branch is reachable") {
6368
assertEquivalent(
64-
CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
69+
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
6570
Literal(30))
6671

6772
assertEquivalent(
@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
7176

7277
test("remove entire CaseWhen if the first branch is always true") {
7378
assertEquivalent(
74-
CaseWhen(trueBranch :: normalBranch :: Nil, None),
79+
CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
7580
Literal(5))
7681

7782
// Test branch elimination and simplification in combination
7883
assertEquivalent(
79-
CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
84+
CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
85+
:: Nil, None),
8086
Literal(5))
8187

8288
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch

0 commit comments

Comments
 (0)