Skip to content

[SPARK-33845][SQL] Remove unnecessary if when trueValue and falseValue are foldable boolean types #30849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
// Collapsing If(cond, true, false) to `cond` is only sound when `cond` can never be null:
// a null predicate makes If pick the false branch (see the Literal(null, _) case above),
// whereas the bare `cond` would evaluate to null. Guard on nullability, as the
// null-literal rule below already does.
case If(cond, TrueLiteral, FalseLiteral) if !cond.nullable => cond
// Same reasoning for the swapped branches: If(null, false, true) yields true,
// while Not(null) would yield null.
case If(cond, FalseLiteral, TrueLiteral) if !cond.nullable => Not(cond)
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite

test("Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral))
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a))

// Push down at most one non-foldable expression.
assertEquivalent(
Expand All @@ -67,7 +67,7 @@ class PushFoldableIntoBranchesSuite
val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral))
GreaterThanOrEqual(Rand(1), Literal(0.5)))
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral))

Expand Down Expand Up @@ -102,8 +102,7 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
If(a, Literal(2.0), Literal(3.0)))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral),
If(a, FalseLiteral, TrueLiteral))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a))
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
Expand Down Expand Up @@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(2) === nestedCaseWhen,
TrueLiteral,
FalseLiteral)
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
val condition = CaseWhen(branches)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond =
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen)))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
}

test("inability to replace null in non-boolean branches of If inside another If") {
Expand All @@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(3)),
TrueLiteral,
FalseLiteral)
testFilter(originalCond = condition, expectedCond = condition)
testJoin(originalCond = condition, expectedCond = condition)
testDelete(originalCond = condition, expectedCond = condition)
testUpdate(originalCond = condition, expectedCond = condition)
val expectedCond = Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
testUpdate(originalCond = condition, expectedCond = expectedCond)
}

test("replace null in If used as a join condition") {
Expand Down Expand Up @@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val lambda1 = LambdaFunction(
function = If(cond, Literal(null, BooleanType), TrueLiteral),
arguments = lambdaArgs)
// the optimized lambda body is: if(arg > 0, false, true)
// the optimized lambda body is: if(arg > 0, false, true) => arg <= 0
val lambda2 = LambdaFunction(
function = If(cond, FalseLiteral, TrueLiteral),
function = LessThanOrEqual(condArg, Literal(0)),
arguments = lambdaArgs)
testProjection(
originalExpr = createExpr(argument, lambda1) as 'x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,20 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
}
}

test("SPARK-33845: remove unnecessary if when the outputs are boolean type") {
  // Local builders: each call yields a fresh expression tree, so the input and the
  // expected side of every assertion never share a node instance.
  def attrA = UnresolvedAttribute("a")
  def attrANotNull = IsNotNull(attrA)
  def randAboveA = GreaterThan(Rand(0), attrA)

  // Condition built from IsNotNull: true/false branches reduce to the predicate itself,
  // swapped branches are expected to come out as IsNull.
  assertEquivalent(If(attrANotNull, TrueLiteral, FalseLiteral), attrANotNull)
  assertEquivalent(If(attrANotNull, FalseLiteral, TrueLiteral), IsNull(attrA))

  // Condition built on Rand: the same reduction is expected, with the swapped-branch
  // form showing up as the inverted comparison.
  assertEquivalent(If(randAboveA, TrueLiteral, FalseLiteral), randAboveA)
  assertEquivalent(
    If(randAboveA, FalseLiteral, TrueLiteral),
    LessThanOrEqual(Rand(0), attrA))
}
}