-
Notifications
You must be signed in to change notification settings - Fork 28.5k
[SPARK-24892] [SQL] Simplify `CaseWhen` to `If` when there is only one branch
#21850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
val cond = branches.head._1 | ||
val trueValue = branches.head._2 | ||
val falseValue = elseValue.getOrElse(Literal(null, trueValue.dataType)) | ||
If(cond, trueValue, falseValue) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we simplify more like the following pattern?
- case CaseWhen(branches, elseValue) if branches.length == 1 =>
- val cond = branches.head._1
- val trueValue = branches.head._2
- val falseValue = elseValue.getOrElse(Literal(null, trueValue.dataType))
- If(cond, trueValue, falseValue)
+ case e @ CaseWhen((cond, branchValue) :: Nil, elseValue) =>
+ If(cond, branchValue, elseValue.getOrElse(Literal(null, e.dataType)))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks!
Test build #93456 has finished for PR 21850 at commit
|
Test build #93462 has finished for PR 21850 at commit
|
retest this please |
Test build #93468 has finished for PR 21850 at commit
|
retest this please |
1 similar comment
retest this please |
Test build #93484 has finished for PR 21850 at commit
|
retest this please |
Test build #93492 has finished for PR 21850 at commit
|
@@ -61,7 +61,17 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { | |||
// i.e. removing branches whose conditions are always false | |||
assertEquivalent( | |||
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), | |||
CaseWhen(normalBranch :: Nil, None)) | |||
If(normalBranch._1, normalBranch._2, Literal(null, normalBranch._2.dataType))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add an extra branch into this test, so it won't be optimized to single branch CaseWhen? Otherwise this test is changed from its original purpose.
@@ -414,6 +414,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { | |||
// these branches can be pruned away | |||
val (h, t) = branches.span(_._1 != TrueLiteral) | |||
CaseWhen( h :+ t.head, None) | |||
|
|||
case CaseWhen((cond, branchValue) :: Nil, elseValue) => | |||
If(cond, branchValue, elseValue.getOrElse(Literal(null, branchValue.dataType))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you post the difference of generated JAVA code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before:
== Parsed Logical Plan ==
'Project [CASE WHEN isnull('a) THEN 1 END AS col1#181]
+- 'UnresolvedRelation
== Optimized Logical Plan ==
Project [CASE WHEN isnull(a#182) THEN 1 END AS col1#181]
+- Relation[a#182] parquet
Generated Java code
/* 043 */ protected void processNext() throws java.io.IOException {
/* 044 */ if (scan_mutableStateArray_1[0] == null) {
/* 045 */ scan_nextBatch_0();
/* 046 */ }
/* 047 */ while (scan_mutableStateArray_1[0] != null) {
/* 048 */ int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
/* 049 */ int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
/* 050 */ for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
/* 051 */ int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
/* 052 */ byte project_caseWhenResultState_0 = -1;
/* 053 */ do {
/* 054 */ boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
/* 055 */ int scan_value_0 = scan_isNull_0 ? -1 : (scan_mutableStateArray_2[0].getInt(scan_rowIdx_0));
/* 056 */ if (!false && scan_isNull_0) {
/* 057 */ project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 058 */ project_project_value_0_0 = 1;
/* 059 */ continue;
/* 060 */ }
/* 061 */
/* 062 */ } while (false);
/* 063 */ // TRUE if any condition is met and the result is null, or no any condition is met.
/* 064 */ final boolean project_isNull_0 = (project_caseWhenResultState_0 != 0);
/* 065 */ scan_mutableStateArray_3[1].reset();
/* 066 */
/* 067 */ scan_mutableStateArray_3[1].zeroOutNullBytes();
/* 068 */
/* 069 */ if (project_isNull_0) {
/* 070 */ scan_mutableStateArray_3[1].setNullAt(0);
/* 071 */ } else {
/* 072 */ scan_mutableStateArray_3[1].write(0, project_project_value_0_0);
/* 073 */ }
/* 074 */ append((scan_mutableStateArray_3[1].getRow()));
/* 075 */ if (shouldStop()) { scan_batchIdx_0 = scan_rowIdx_0 + 1; return; }
/* 076 */ }
/* 077 */ scan_batchIdx_0 = scan_numRows_0;
/* 078 */ scan_mutableStateArray_1[0] = null;
/* 079 */ scan_nextBatch_0();
/* 080 */ }
/* 081 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
/* 082 */ scan_scanTime_0 = 0;
/* 083 */ }
After:
== Parsed Logical Plan ==
'Project [CASE WHEN isnull('a) THEN 1 END AS b#186]
+- 'UnresolvedRelation `tddddd`
== Optimized Logical Plan ==
Project [if (isnull(a#187)) 1 else null AS b#186]
+- Relation[a#187,b#188] parquet
Generated Java code:
/* 042 */ protected void processNext() throws java.io.IOException {
/* 043 */ if (scan_mutableStateArray_1[0] == null) {
/* 044 */ scan_nextBatch_0();
/* 045 */ }
/* 046 */ while (scan_mutableStateArray_1[0] != null) {
/* 047 */ int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
/* 048 */ int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
/* 049 */ for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
/* 050 */ int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
/* 051 */ boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
/* 052 */ int scan_value_0 = scan_isNull_0 ? -1 : (scan_mutableStateArray_2[0].getInt(scan_rowIdx_0));
/* 053 */ boolean project_isNull_0 = false;
/* 054 */ int project_value_0 = -1;
/* 055 */ if (!false && scan_isNull_0) {
/* 056 */ project_isNull_0 = false;
/* 057 */ project_value_0 = 1;
/* 058 */ } else {
/* 059 */ project_isNull_0 = true;
/* 060 */ project_value_0 = -1;
/* 061 */ }
/* 062 */ scan_mutableStateArray_3[1].reset();
/* 063 */
/* 064 */ scan_mutableStateArray_3[1].zeroOutNullBytes();
/* 065 */
/* 066 */ if (project_isNull_0) {
/* 067 */ scan_mutableStateArray_3[1].setNullAt(0);
/* 068 */ } else {
/* 069 */ scan_mutableStateArray_3[1].write(0, project_value_0);
/* 070 */ }
/* 071 */ append((scan_mutableStateArray_3[1].getRow()));
/* 072 */ if (shouldStop()) { scan_batchIdx_0 = scan_rowIdx_0 + 1; return; }
/* 073 */ }
/* 074 */ scan_batchIdx_0 = scan_numRows_0;
/* 075 */ scan_mutableStateArray_1[0] = null;
/* 076 */ scan_nextBatch_0();
/* 077 */ }
/* 078 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
/* 079 */ scan_scanTime_0 = 0;
/* 080 */ }
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like there is not much difference in terms of performance, but the `If` primitive has more opportunities for further optimization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For CaseWhen case, looks like there is an extra do while loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, `CaseWhen` has an additional `project_project_value_0_0` field declared outside the method.
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private scala.collection.Iterator inputadapter_input_0;
/* 010 */ private int project_project_value_0_0;
|
||
case CaseWhen(branches, elseValue) if branches.length == 1 => | ||
// Using pattern matching like `CaseWhen((cond, branchValue) :: Nil, elseValue)` will not | ||
// work since the implementation of `branches` can be `ArrayBuffer`. A full test is in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, when is it an `ArrayBuffer`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`sql("select case when a is null then 1 end col1 from t")` will create `branches` with an `ArrayBuffer` implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can change it to an immutable `List` to avoid confusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about:
case CaseWhen(Seq((cond, trueValue)), elseValue) =>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ueshin thanks! The code is much cleaner.
Test build #93521 has finished for PR 21850 at commit
|
retest this please |
Test build #93525 has finished for PR 21850 at commit
|
LGTM |
Personally, I do not think we need this extra case.
Could you explain more? |
|
||
checkAnswer(plan1, Row(null) :: Row(1) :: Row(null) :: Nil) | ||
comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for adding this higher level test, too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, LGTM.
Test build #93554 has finished for PR 21850 at commit
|
retest this please. |
@gatorsmile All the new rules added into But there will be time that we only add For example, the new rule that short-circuiting the |
Test build #93564 has finished for PR 21850 at commit
|
@@ -414,6 +414,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { | |||
// these branches can be pruned away | |||
val (h, t) = branches.span(_._1 != TrueLiteral) | |||
CaseWhen( h :+ t.head, None) | |||
|
|||
case CaseWhen(Seq((cond, trueValue)), elseValue) => | |||
If(cond, trueValue, elseValue.getOrElse(Literal(null, trueValue.dataType))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think `If` is faster than `CaseWhen`. Can you explain more about "further optimization"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The generated Java code is slightly simpler, but I agree there should not be any performance gain. That being said, once `CaseWhen` is converted into `If`, the conditional expression will benefit from the optimization rules for `If`, which may not be implemented for the `CaseWhen` case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optimization rules in If which may not be implemented for CaseWhen case.
shall we just implement more optimizer rules for CASE WHEN to cover all the cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's revisit this PR later, and we should always try to add CASE WHEN version for parity.
Here is the one for case when.
#21852
Test build #128016 has finished for PR 21850 at commit
|
e2b0e96
to
dd584a0
Compare
Test build #128070 has finished for PR 21850 at commit
|
@@ -200,13 +200,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { | |||
|
|||
test("inability to replace null in non-boolean values of CaseWhen") { | |||
val nestedCaseWhen = CaseWhen( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those tests are modified, and one branch is added to each `CaseWhen`, to prevent the new rule from converting `CaseWhen` to `If`, for better test coverage.
@@ -505,6 +505,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { | |||
} else { | |||
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) | |||
} | |||
|
|||
case CaseWhen(Seq((cond, trueValue)), elseValue) => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be better to limit this to the `BooleanType` case? I.e.,
case cw @ CaseWhen(Seq((cond, trueValue)), elseValue) if cw.dataType == BooleanType =>
The reason is that most of the further optimization comes from #29567, which applies to the boolean type case only.
Or just rewrite it similarly like #29567?
case CaseWhen(Seq((cond, l @ Literal(null, _))), FalseLiteral) if !cond.nullable => ...
case CaseWhen(Seq((cond, l @ Literal(null, _))), TrueLiteral) if !cond.nullable => ...
Test build #128112 has finished for PR 21850 at commit
|
We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable. |
What changes were proposed in this pull request?
After the rule that removes unreachable branches has run, there could be only one branch left. In this situation, `CaseWhen` can be converted to `If` to enable further optimization.
How was this patch tested?
Tests added.