-
Notifications
You must be signed in to change notification settings - Fork 28.5k
[SPARK-24892] [SQL] Simplify `CaseWhen` to `If` when there is only one branch
#21850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -505,6 +505,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { | |
} else { | ||
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) | ||
} | ||
|
||
case CaseWhen(Seq((cond, trueValue)), elseValue) => | ||
If(cond, trueValue, elseValue.getOrElse(Literal(null, trueValue.dataType))) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The generated Java code is slightly simpler, but I agree there should not be any performance gain. That being said, once There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Shall we just implement more optimizer rules for CASE WHEN to cover all the cases? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's revisit this PR later, and we should always try to add a CASE WHEN version for parity. Here is the one for CASE WHEN. |
||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -200,13 +200,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
|
||
test("inability to replace null in non-boolean values of CaseWhen") { | ||
val nestedCaseWhen = CaseWhen( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Those tests are modified and one branch is added in |
||
Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)), | ||
Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2), | ||
(UnresolvedAttribute("i") > Literal(25)) -> Literal(3)), | ||
Literal(null, IntegerType)) | ||
val branchValue = If( | ||
Literal(2) === nestedCaseWhen, | ||
TrueLiteral, | ||
FalseLiteral) | ||
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) | ||
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue, | ||
UnresolvedAttribute("b").isNull -> TrueLiteral) | ||
val condition = CaseWhen(branches) | ||
testFilter(originalCond = condition, expectedCond = condition) | ||
testJoin(originalCond = condition, expectedCond = condition) | ||
|
@@ -304,7 +306,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |
val condition = GreaterThan( | ||
UnresolvedAttribute("i"), | ||
If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) | ||
val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out") | ||
val column = CaseWhen(Seq(condition -> Literal(5), | ||
UnresolvedAttribute("b").isNotNull -> Literal(5)), Literal(2)).as("out") | ||
testProjection(originalExpr = column, expectedExpr = column) | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3582,6 +3582,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark | |
checkAnswer(sql("SELECT 0 FROM ( SELECT * FROM B JOIN C USING (id)) " + | ||
"JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0)) | ||
} | ||
|
||
test("SPARK-24892: simplify `CaseWhen` to `If` when there is only one branch") { | ||
withTable("t") { | ||
Seq(Some(1), null, Some(3)).toDF("a").write.saveAsTable("t") | ||
|
||
val plan1 = sql("select case when a is null then 1 end col1 from t") | ||
val plan2 = sql("select if(a is null, 1, null) col1 from t") | ||
|
||
checkAnswer(plan1, Row(null) :: Row(1) :: Row(null) :: Nil) | ||
comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for adding this higher level test, too. |
||
} | ||
} | ||
|
||
case class Foo(bar: Option[String]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be better to limit this to the `BooleanType` case? The reason is that most of the further optimization comes from #29567, which applies to the boolean type case only.
Or should we just rewrite it similarly to #29567?