[SPARK-24892] [SQL] Simplify CaseWhen to If when there is only one branch #21850

Closed
wants to merge 2 commits

Conversation

dbtsai (Member) commented Jul 23, 2018

What changes were proposed in this pull request?

After the rule that removes unreachable branches runs, only one branch may be left. In this situation, the CaseWhen can be converted to an If, which enables further optimization.
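As an illustration (a hedged sketch using Catalyst's internal expression classes; the attribute a and the literals are made up for this example, not taken from the PR), the two-step rewrite looks like this:

// Hypothetical sketch, not code from the PR: prune the unreachable branch,
// then rewrite the remaining single-branch CaseWhen as an If.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CaseWhen, If, IsNull, Literal}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()

// CASE WHEN false THEN 0 WHEN a IS NULL THEN 1 END
val original = CaseWhen(Seq((Literal(false), Literal(0)), (IsNull(a), Literal(1))), None)

// After removing the unreachable branch: CASE WHEN a IS NULL THEN 1 END
val pruned = CaseWhen(Seq((IsNull(a), Literal(1))), None)

// After this PR's rule: if (isnull(a)) 1 else null
val simplified = If(IsNull(a), Literal(1), Literal(null, IntegerType))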

How was this patch tested?

Tests added.

val cond = branches.head._1
val trueValue = branches.head._2
val falseValue = elseValue.getOrElse(Literal(null, trueValue.dataType))
If(cond, trueValue, falseValue)
Member:

Can we simplify it further, like the following pattern?

-      case CaseWhen(branches, elseValue) if branches.length == 1 =>
-        val cond = branches.head._1
-        val trueValue = branches.head._2
-        val falseValue = elseValue.getOrElse(Literal(null, trueValue.dataType))
-        If(cond, trueValue, falseValue)

+      case e @ CaseWhen((cond, branchValue) :: Nil, elseValue) =>
+        If(cond, branchValue, elseValue.getOrElse(Literal(null, e.dataType)))

Member Author:

Done. Thanks!


SparkQA commented Jul 23, 2018

Test build #93456 has finished for PR 21850 at commit 18e2d8d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

dbtsai commented Jul 23, 2018

@cloud-fan and @gatorsmile


SparkQA commented Jul 24, 2018

Test build #93462 has finished for PR 21850 at commit a9c97ce.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

dbtsai commented Jul 24, 2018

retest this please


SparkQA commented Jul 24, 2018

Test build #93468 has finished for PR 21850 at commit a9c97ce.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

dbtsai commented Jul 24, 2018

retest this please

1 similar comment
@HyukjinKwon (Member):
retest this please


SparkQA commented Jul 24, 2018

Test build #93484 has finished for PR 21850 at commit a9c97ce.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@HyukjinKwon (Member):
retest this please


SparkQA commented Jul 24, 2018

Test build #93492 has finished for PR 21850 at commit a9c97ce.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -61,7 +61,17 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
     // i.e. removing branches whose conditions are always false
     assertEquivalent(
       CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
-      CaseWhen(normalBranch :: Nil, None))
+      If(normalBranch._1, normalBranch._2, Literal(null, normalBranch._2.dataType)))
Member:

Can we add an extra branch to this test, so it won't be optimized into a single-branch CaseWhen? Otherwise the test's original purpose is lost.
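For example, something along these lines would keep the original intent (a hedged sketch; the extra anotherNormalBranch is hypothetical and not part of the existing suite):

// With two reachable branches left, the expected result stays a CaseWhen,
// so the branch-pruning behaviour is still what the test exercises.
assertEquivalent(
  CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: anotherNormalBranch :: nullBranch :: Nil, None),
  CaseWhen(normalBranch :: anotherNormalBranch :: Nil, None))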

@@ -414,6 +414,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
         // these branches can be pruned away
         val (h, t) = branches.span(_._1 != TrueLiteral)
         CaseWhen( h :+ t.head, None)
+
+      case CaseWhen((cond, branchValue) :: Nil, elseValue) =>
+        If(cond, branchValue, elseValue.getOrElse(Literal(null, branchValue.dataType)))
Member:

Could you post the difference in the generated Java code?

Member Author:

Before:

== Parsed Logical Plan ==
'Project [CASE WHEN isnull('a) THEN 1 END AS col1#181]
+- 'UnresolvedRelation

== Optimized Logical Plan ==
Project [CASE WHEN isnull(a#182) THEN 1 END AS col1#181]
+- Relation[a#182] parquet

Generated Java code:

/* 043 */   protected void processNext() throws java.io.IOException {
/* 044 */     if (scan_mutableStateArray_1[0] == null) {
/* 045 */       scan_nextBatch_0();
/* 046 */     }
/* 047 */     while (scan_mutableStateArray_1[0] != null) {
/* 048 */       int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
/* 049 */       int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
/* 050 */       for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
/* 051 */         int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
/* 052 */         byte project_caseWhenResultState_0 = -1;
/* 053 */         do {
/* 054 */           boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
/* 055 */           int scan_value_0 = scan_isNull_0 ? -1 : (scan_mutableStateArray_2[0].getInt(scan_rowIdx_0));
/* 056 */           if (!false && scan_isNull_0) {
/* 057 */             project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 058 */             project_project_value_0_0 = 1;
/* 059 */             continue;
/* 060 */           }
/* 061 */
/* 062 */         } while (false);
/* 063 */         // TRUE if any condition is met and the result is null, or no any condition is met.
/* 064 */         final boolean project_isNull_0 = (project_caseWhenResultState_0 != 0);
/* 065 */         scan_mutableStateArray_3[1].reset();
/* 066 */
/* 067 */         scan_mutableStateArray_3[1].zeroOutNullBytes();
/* 068 */
/* 069 */         if (project_isNull_0) {
/* 070 */           scan_mutableStateArray_3[1].setNullAt(0);
/* 071 */         } else {
/* 072 */           scan_mutableStateArray_3[1].write(0, project_project_value_0_0);
/* 073 */         }
/* 074 */         append((scan_mutableStateArray_3[1].getRow()));
/* 075 */         if (shouldStop()) { scan_batchIdx_0 = scan_rowIdx_0 + 1; return; }
/* 076 */       }
/* 077 */       scan_batchIdx_0 = scan_numRows_0;
/* 078 */       scan_mutableStateArray_1[0] = null;
/* 079 */       scan_nextBatch_0();
/* 080 */     }
/* 081 */     ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
/* 082 */     scan_scanTime_0 = 0;
/* 083 */   }

After:

== Parsed Logical Plan ==
'Project [CASE WHEN isnull('a) THEN 1 END AS b#186]
+- 'UnresolvedRelation `tddddd`

== Optimized Logical Plan ==
Project [if (isnull(a#187)) 1 else null AS b#186]
+- Relation[a#187,b#188] parquet

Generated Java code:

/* 042 */   protected void processNext() throws java.io.IOException {
/* 043 */     if (scan_mutableStateArray_1[0] == null) {
/* 044 */       scan_nextBatch_0();
/* 045 */     }
/* 046 */     while (scan_mutableStateArray_1[0] != null) {
/* 047 */       int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
/* 048 */       int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
/* 049 */       for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
/* 050 */         int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
/* 051 */         boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
/* 052 */         int scan_value_0 = scan_isNull_0 ? -1 : (scan_mutableStateArray_2[0].getInt(scan_rowIdx_0));
/* 053 */         boolean project_isNull_0 = false;
/* 054 */         int project_value_0 = -1;
/* 055 */         if (!false && scan_isNull_0) {
/* 056 */           project_isNull_0 = false;
/* 057 */           project_value_0 = 1;
/* 058 */         } else {
/* 059 */           project_isNull_0 = true;
/* 060 */           project_value_0 = -1;
/* 061 */         }
/* 062 */         scan_mutableStateArray_3[1].reset();
/* 063 */
/* 064 */         scan_mutableStateArray_3[1].zeroOutNullBytes();
/* 065 */
/* 066 */         if (project_isNull_0) {
/* 067 */           scan_mutableStateArray_3[1].setNullAt(0);
/* 068 */         } else {
/* 069 */           scan_mutableStateArray_3[1].write(0, project_value_0);
/* 070 */         }
/* 071 */         append((scan_mutableStateArray_3[1].getRow()));
/* 072 */         if (shouldStop()) { scan_batchIdx_0 = scan_rowIdx_0 + 1; return; }
/* 073 */       }
/* 074 */       scan_batchIdx_0 = scan_numRows_0;
/* 075 */       scan_mutableStateArray_1[0] = null;
/* 076 */       scan_nextBatch_0();
/* 077 */     }
/* 078 */     ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
/* 079 */     scan_scanTime_0 = 0;
/* 080 */   }

Member Author:

It looks like there is not much difference in terms of performance, but the If primitive has more opportunities for further optimization.

Member:

For the CaseWhen case, it looks like there is an extra do-while loop?

Member:

Also, the CaseWhen version declares an additional project_project_value_0_0 field outside, at the class level.

/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private int project_project_value_0_0;


case CaseWhen(branches, elseValue) if branches.length == 1 =>
// Using pattern matching like `CaseWhen((cond, branchValue) :: Nil, elseValue)` will not
// work since the implementation of `branches` can be `ArrayBuffer`. A full test is in
Member:

Just curious, when is it an ArrayBuffer?

Member Author:

sql("select case when a is null then 1 end col1 from t") will create branches backed by an ArrayBuffer implementation.
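A minimal, standalone Scala sketch (plain Scala, not Spark code) of why the cons pattern misses here:

// The `:: Nil` pattern only matches scala.collection.immutable.List, while the
// `Seq(...)` extractor matches any Seq, including a mutable ArrayBuffer.
import scala.collection.mutable.ArrayBuffer

val branches: Seq[(String, Int)] = ArrayBuffer(("cond", 1))

branches match {
  case (c, v) :: Nil => println(s"matched List: $c -> $v")     // not taken for an ArrayBuffer
  case Seq((c, v))   => println(s"matched any Seq: $c -> $v")  // taken
  case _             => println("no match")
}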

Member Author:

We can change it to an immutable List to avoid confusion.

Member:

How about:

case CaseWhen(Seq((cond, trueValue)), elseValue) =>

Member:

+1 for @ueshin's suggestion. And, sorry for the trouble, @dbtsai. :)

Member Author:

@ueshin thanks! The code is much cleaner now.


SparkQA commented Jul 25, 2018

Test build #93521 has finished for PR 21850 at commit 59fada7.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

dbtsai commented Jul 25, 2018

retest this please


SparkQA commented Jul 25, 2018

Test build #93525 has finished for PR 21850 at commit 59fada7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

viirya commented Jul 25, 2018

LGTM

@gatorsmile (Member):
Personally, I do not think we need this extra case.

"If primitive has more opportunities for further optimization."

Could you explain more?


checkAnswer(plan1, Row(null) :: Row(1) :: Row(null) :: Nil)
comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
}
Member:

Thank you for adding this higher-level test, too.

@dongjoon-hyun (Member) left a comment:

+1, LGTM.


SparkQA commented Jul 25, 2018

Test build #93554 has finished for PR 21850 at commit e2b0e96.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

viirya commented Jul 25, 2018

retest this please.

dbtsai commented Jul 25, 2018

@gatorsmile Ideally, every new rule added for If should also have a CaseWhen version.

But there will be times when we only add the If version, or when it only makes sense to have an If version.

For example, the new rule that short-circuits an If when both branches are the same is not yet implemented for CaseWhen. Another rule I am working on: when a conditional expression appears in a Filter or Join condition and any branch's output is Literal(null, _), that branch can be replaced by FalseLiteral. So far I have only implemented it for If, for our use case.
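A hedged sketch of the kind of If-only simplifications being described (illustrative only, not the exact Spark rules; it assumes Catalyst's internal Expression API):

import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal}
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral

def simplifyIf(e: Expression): Expression = e match {
  // Both branches are semantically the same: short-circuit the If
  // (only safe when the predicate is deterministic).
  case If(cond, trueValue, falseValue)
      if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
  // Inside a Filter/Join predicate, a branch that evaluates to null behaves
  // like false, so it can be replaced by FalseLiteral.
  case i @ If(_, Literal(null, _), FalseLiteral) => i.copy(trueValue = FalseLiteral)
  case other => other
}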


SparkQA commented Jul 25, 2018

Test build #93564 has finished for PR 21850 at commit e2b0e96.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -414,6 +414,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
         // these branches can be pruned away
         val (h, t) = branches.span(_._1 != TrueLiteral)
         CaseWhen( h :+ t.head, None)
+
+      case CaseWhen(Seq((cond, trueValue)), elseValue) =>
+        If(cond, trueValue, elseValue.getOrElse(Literal(null, trueValue.dataType)))
Contributor:

I don't think If is faster than CaseWhen; can you explain more about "further optimization"?

Member Author:

The generated Java code is slightly simpler, but I agree there should not be any performance gain. That being said, once the CaseWhen is converted into an If, the conditional expression benefits from the optimization rules for If, which may not be implemented for the CaseWhen case.

Contributor:

"optimization rules in If which may not be implemented for CaseWhen case"

Shall we just implement more optimizer rules for CASE WHEN to cover all the cases?

Member Author:

Let's revisit this PR later; we should always try to add a CASE WHEN version for parity.

Here is the one for CASE WHEN: #21852


SparkQA commented Aug 29, 2020

Test build #128016 has finished for PR 21850 at commit e2b0e96.

  • This patch fails Scala style tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.


SparkQA commented Aug 31, 2020

Test build #128070 has finished for PR 21850 at commit dd584a0.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -200,13 +200,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {

test("inability to replace null in non-boolean values of CaseWhen") {
val nestedCaseWhen = CaseWhen(
Member Author:

These tests are modified: an extra branch is added to the CaseWhen so the new rule does not convert it to an If, which preserves the original test coverage.

@@ -505,6 +505,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
        } else {
          e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
        }
+
+      case CaseWhen(Seq((cond, trueValue)), elseValue) =>
Member:

Is it better to limit this to the BooleanType case? I.e.,

case cw @ CaseWhen(Seq((cond, trueValue)), elseValue) if cw.dataType == BooleanType =>

The reason is that the further optimization mostly comes from #29567, which applies only to the boolean type case.

Or just rewrite it in a similar way to #29567?

case CaseWhen(Seq((cond, l @ Literal(null, _))), FalseLiteral) if !cond.nullable => ...
case CaseWhen(Seq((cond, l @ Literal(null, _))), TrueLiteral) if !cond.nullable => ...


SparkQA commented Aug 31, 2020

Test build #128112 has finished for PR 21850 at commit d34e908.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@github-actions

We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
If you'd like to revive this PR, please reopen it and ask a committer to remove the Stale tag!
