Skip to content

Commit 872107f

Browse files
wangyum authored and cloud-fan committed
[SPARK-33848][SQL][FOLLOWUP] Introduce allowList for push into (if / case) branches
### What changes were proposed in this pull request?
Introduce an allowList for pushing expressions into (if / case) branches to fix a potential bug.

### Why are the changes needed?
Fix a potential bug.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #30955 from wangyum/SPARK-33848-2.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 3b1b209 commit 872107f

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 34 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -553,41 +553,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
553553
foldables.nonEmpty && others.length < 2
554554
}
555555

556+
// Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
557+
private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match {
558+
case _: IsNull | _: IsNotNull => true
559+
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
560+
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
561+
true
562+
case _: CastBase => true
563+
case _: GetDateField | _: LastDay => true
564+
case _: ExtractIntervalPart => true
565+
case _: ArraySetLike => true
566+
case _: ExtractValue => true
567+
case _ => false
568+
}
569+
570+
// Not all BinaryExpression can be pushed into (if / case) branches.
571+
private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match {
572+
case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true
573+
case _: BinaryArithmetic => true
574+
case _: BinaryMathExpression => true
575+
case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true
576+
case _: FindInSet | _: RoundBase => true
577+
case _ => false
578+
}
579+
556580
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
557581
case q: LogicalPlan => q transformExpressionsUp {
558-
case a: Alias => a // Skip an alias.
559582
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
560-
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
583+
if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
561584
i.copy(
562585
trueValue = u.withNewChildren(Array(trueValue)),
563586
falseValue = u.withNewChildren(Array(falseValue)))
564587

565588
case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
566-
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
589+
if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
567590
c.copy(
568591
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
569592
elseValue.map(e => u.withNewChildren(Array(e))))
570593

571594
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
572-
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
595+
if supportedBinaryExpression(b) && right.foldable &&
596+
atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
573597
i.copy(
574598
trueValue = b.withNewChildren(Array(trueValue, right)),
575599
falseValue = b.withNewChildren(Array(falseValue, right)))
576600

577601
case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
578-
if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
602+
if supportedBinaryExpression(b) && left.foldable &&
603+
atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
579604
i.copy(
580605
trueValue = b.withNewChildren(Array(left, trueValue)),
581606
falseValue = b.withNewChildren(Array(left, falseValue)))
582607

583608
case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
584-
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
609+
if supportedBinaryExpression(b) && right.foldable &&
610+
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
585611
c.copy(
586612
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
587613
elseValue.map(e => b.withNewChildren(Array(e, right))))
588614

589615
case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
590-
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
616+
if supportedBinaryExpression(b) && left.foldable &&
617+
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
591618
c.copy(
592619
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
593620
elseValue.map(e => b.withNewChildren(Array(left, e))))

0 commit comments

Comments
 (0)