@@ -553,41 +553,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
553
553
foldables.nonEmpty && others.length < 2
554
554
}
555
555
556
+ // Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
557
+ private def supportedUnaryExpression (e : UnaryExpression ): Boolean = e match {
558
+ case _ : IsNull | _ : IsNotNull => true
559
+ case _ : UnaryMathExpression | _ : Abs | _ : Bin | _ : Factorial | _ : Hex => true
560
+ case _ : String2StringExpression | _ : Ascii | _ : Base64 | _ : BitLength | _ : Chr | _ : Length =>
561
+ true
562
+ case _ : CastBase => true
563
+ case _ : GetDateField | _ : LastDay => true
564
+ case _ : ExtractIntervalPart => true
565
+ case _ : ArraySetLike => true
566
+ case _ : ExtractValue => true
567
+ case _ => false
568
+ }
569
+
570
+ // Not all BinaryExpression can be pushed into (if / case) branches.
571
+ private def supportedBinaryExpression (e : BinaryExpression ): Boolean = e match {
572
+ case _ : BinaryComparison | _ : StringPredicate | _ : StringRegexExpression => true
573
+ case _ : BinaryArithmetic => true
574
+ case _ : BinaryMathExpression => true
575
+ case _ : AddMonths | _ : DateAdd | _ : DateAddInterval | _ : DateDiff | _ : DateSub => true
576
+ case _ : FindInSet | _ : RoundBase => true
577
+ case _ => false
578
+ }
579
+
556
580
def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
557
581
case q : LogicalPlan => q transformExpressionsUp {
558
- case a : Alias => a // Skip an alias.
559
582
case u @ UnaryExpression (i @ If (_, trueValue, falseValue))
560
- if atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
583
+ if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
561
584
i.copy(
562
585
trueValue = u.withNewChildren(Array (trueValue)),
563
586
falseValue = u.withNewChildren(Array (falseValue)))
564
587
565
588
case u @ UnaryExpression (c @ CaseWhen (branches, elseValue))
566
- if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
589
+ if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
567
590
c.copy(
568
591
branches.map(e => e.copy(_2 = u.withNewChildren(Array (e._2)))),
569
592
elseValue.map(e => u.withNewChildren(Array (e))))
570
593
571
594
case b @ BinaryExpression (i @ If (_, trueValue, falseValue), right)
572
- if right.foldable && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
595
+ if supportedBinaryExpression(b) && right.foldable &&
596
+ atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
573
597
i.copy(
574
598
trueValue = b.withNewChildren(Array (trueValue, right)),
575
599
falseValue = b.withNewChildren(Array (falseValue, right)))
576
600
577
601
case b @ BinaryExpression (left, i @ If (_, trueValue, falseValue))
578
- if left.foldable && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
602
+ if supportedBinaryExpression(b) && left.foldable &&
603
+ atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
579
604
i.copy(
580
605
trueValue = b.withNewChildren(Array (left, trueValue)),
581
606
falseValue = b.withNewChildren(Array (left, falseValue)))
582
607
583
608
case b @ BinaryExpression (c @ CaseWhen (branches, elseValue), right)
584
- if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
609
+ if supportedBinaryExpression(b) && right.foldable &&
610
+ atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
585
611
c.copy(
586
612
branches.map(e => e.copy(_2 = b.withNewChildren(Array (e._2, right)))),
587
613
elseValue.map(e => b.withNewChildren(Array (e, right))))
588
614
589
615
case b @ BinaryExpression (left, c @ CaseWhen (branches, elseValue))
590
- if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
616
+ if supportedBinaryExpression(b) && left.foldable &&
617
+ atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
591
618
c.copy(
592
619
branches.map(e => e.copy(_2 = b.withNewChildren(Array (left, e._2)))),
593
620
elseValue.map(e => b.withNewChildren(Array (left, e))))
0 commit comments