Skip to content

Commit 090890c

Browse files
committed
Improve SimplifyConditionals
1 parent 2da6885 commit 090890c

File tree

4 files changed

+38
-33
lines changed

4 files changed

+38
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,15 +560,13 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
560560
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
561561
c.copy(
562562
branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))),
563-
elseValue.orElse(Some(Literal.create(null, right.dataType)))
564-
.map(e => b.makeCopy(Array(e, right))))
563+
elseValue.map(e => b.makeCopy(Array(e, right))))
565564

566565
case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
567566
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
568567
c.copy(
569568
branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))),
570-
elseValue.orElse(Some(Literal.create(null, left.dataType)))
571-
.map(e => b.makeCopy(Array(left, e))))
569+
elseValue.map(e => b.makeCopy(Array(left, e))))
572570
}
573571
}
574572
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,15 @@ class PushFoldableIntoBranchesSuite
124124
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
125125
assertEquivalent(
126126
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)),
127-
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), nullBoolean))
128-
assertEquivalent(
129-
EqualTo(CaseWhen(Seq((a, nullInt), (c, nullInt)), None), Literal(4)), nullBoolean)
127+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
130128

131129
assertEquivalent(
132130
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
133131
FalseLiteral)
134132

135133
// Push down only when at most one branch value is not a foldable expression.
136134
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)),
137-
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), nullBoolean))
135+
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None))
138136
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)),
139137
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)))
140138
assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)),
@@ -225,7 +223,5 @@ class PushFoldableIntoBranchesSuite
225223
assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral)
226224
assertEquivalent(EqualTo(Literal(4), If(a, nullInt, nullInt)), nullBoolean)
227225
assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral)
228-
assertEquivalent(EqualTo(Literal(4), CaseWhen(Seq((a, nullInt), (c, nullInt)), None)),
229-
nullBoolean)
230226
}
231227
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
2222
import org.apache.spark.sql.catalyst.dsl.expressions._
2323
import org.apache.spark.sql.catalyst.dsl.plans._
24-
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
24+
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
2525
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
2626
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
2727
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
@@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
236236
Literal(2) === nestedCaseWhen,
237237
TrueLiteral,
238238
FalseLiteral)
239-
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
240-
val condition = CaseWhen(branches)
241-
testFilter(originalCond = condition, expectedCond = condition)
242-
testJoin(originalCond = condition, expectedCond = condition)
243-
testDelete(originalCond = condition, expectedCond = condition)
244-
testUpdate(originalCond = condition, expectedCond = condition)
239+
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
240+
val expectedCond =
241+
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen)))
242+
testFilter(originalCond = condition, expectedCond = expectedCond)
243+
testJoin(originalCond = condition, expectedCond = expectedCond)
244+
testDelete(originalCond = condition, expectedCond = expectedCond)
245+
testUpdate(originalCond = condition, expectedCond = expectedCond)
245246
}
246247

247248
test("inability to replace null in non-boolean branches of If inside another If") {
@@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
252253
Literal(3)),
253254
TrueLiteral,
254255
FalseLiteral)
255-
testFilter(originalCond = condition, expectedCond = condition)
256-
testJoin(originalCond = condition, expectedCond = condition)
257-
testDelete(originalCond = condition, expectedCond = condition)
258-
testUpdate(originalCond = condition, expectedCond = condition)
256+
val expectedCond = Literal(5) > If(
257+
UnresolvedAttribute("i") === Literal(15),
258+
Literal(null, IntegerType),
259+
Literal(3))
260+
testFilter(originalCond = condition, expectedCond = expectedCond)
261+
testJoin(originalCond = condition, expectedCond = expectedCond)
262+
testDelete(originalCond = condition, expectedCond = expectedCond)
263+
testUpdate(originalCond = condition, expectedCond = expectedCond)
259264
}
260265

261266
test("replace null in If used as a join condition") {
@@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
405410
val lambda1 = LambdaFunction(
406411
function = If(cond, Literal(null, BooleanType), TrueLiteral),
407412
arguments = lambdaArgs)
408-
// the optimized lambda body is: if(arg > 0, false, true)
413+
// the optimized lambda body is: if(arg > 0, false, true), which is further simplified to: arg <= 0
409414
val lambda2 = LambdaFunction(
410-
function = If(cond, FalseLiteral, TrueLiteral),
415+
function = LessThanOrEqual(condArg, Literal(0)),
411416
arguments = lambdaArgs)
412417
testProjection(
413418
originalExpr = createExpr(argument, lambda1) as 'x,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,6 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
7979
Literal(9)))
8080
}
8181

82-
test("remove unnecessary if when the outputs are boolean type") {
83-
assertEquivalent(
84-
If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
85-
IsNotNull(UnresolvedAttribute("a")))
86-
87-
assertEquivalent(
88-
If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
89-
IsNull(UnresolvedAttribute("a")))
90-
}
91-
9282
test("remove unreachable branches") {
9383
// i.e. removing branches whose conditions are always false
9484
assertEquivalent(
@@ -209,4 +199,20 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
209199
If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow))
210200
}
211201
}
202+
203+
test("remove unnecessary if when the outputs are boolean type") {
204+
assertEquivalent(
205+
If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
206+
IsNotNull(UnresolvedAttribute("a")))
207+
assertEquivalent(
208+
If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
209+
IsNull(UnresolvedAttribute("a")))
210+
211+
assertEquivalent(
212+
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
213+
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral))
214+
assertEquivalent(
215+
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
216+
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral))
217+
}
212218
}

0 commit comments

Comments
 (0)