@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
21
21
import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
22
22
import org .apache .spark .sql .catalyst .dsl .expressions ._
23
23
import org .apache .spark .sql .catalyst .dsl .plans ._
24
- import org .apache .spark .sql .catalyst .expressions .{And , ArrayExists , ArrayFilter , ArrayTransform , CaseWhen , Expression , GreaterThan , If , LambdaFunction , Literal , MapFilter , NamedExpression , Or , UnresolvedNamedLambdaVariable }
24
+ import org .apache .spark .sql .catalyst .expressions .{And , ArrayExists , ArrayFilter , ArrayTransform , CaseWhen , Expression , GreaterThan , If , LambdaFunction , LessThanOrEqual , Literal , MapFilter , NamedExpression , Or , UnresolvedNamedLambdaVariable }
25
25
import org .apache .spark .sql .catalyst .expressions .Literal .{FalseLiteral , TrueLiteral }
26
26
import org .apache .spark .sql .catalyst .plans .{Inner , PlanTest }
27
27
import org .apache .spark .sql .catalyst .plans .logical .{DeleteFromTable , LocalRelation , LogicalPlan , UpdateTable }
@@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
236
236
Literal (2 ) === nestedCaseWhen,
237
237
TrueLiteral ,
238
238
FalseLiteral )
239
- val branches = Seq ((UnresolvedAttribute (" i" ) > Literal (10 )) -> branchValue)
240
- val condition = CaseWhen (branches)
241
- testFilter(originalCond = condition, expectedCond = condition)
242
- testJoin(originalCond = condition, expectedCond = condition)
243
- testDelete(originalCond = condition, expectedCond = condition)
244
- testUpdate(originalCond = condition, expectedCond = condition)
239
+ val condition = CaseWhen (Seq ((UnresolvedAttribute (" i" ) > Literal (10 )) -> branchValue))
240
+ val expectedCond =
241
+ CaseWhen (Seq ((UnresolvedAttribute (" i" ) > Literal (10 )) -> (Literal (2 ) === nestedCaseWhen)))
242
+ testFilter(originalCond = condition, expectedCond = expectedCond)
243
+ testJoin(originalCond = condition, expectedCond = expectedCond)
244
+ testDelete(originalCond = condition, expectedCond = expectedCond)
245
+ testUpdate(originalCond = condition, expectedCond = expectedCond)
245
246
}
246
247
247
248
test(" inability to replace null in non-boolean branches of If inside another If" ) {
@@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
252
253
Literal (3 )),
253
254
TrueLiteral ,
254
255
FalseLiteral )
255
- testFilter(originalCond = condition, expectedCond = condition)
256
- testJoin(originalCond = condition, expectedCond = condition)
257
- testDelete(originalCond = condition, expectedCond = condition)
258
- testUpdate(originalCond = condition, expectedCond = condition)
256
+ val expectedCond = Literal (5 ) > If (
257
+ UnresolvedAttribute (" i" ) === Literal (15 ),
258
+ Literal (null , IntegerType ),
259
+ Literal (3 ))
260
+ testFilter(originalCond = condition, expectedCond = expectedCond)
261
+ testJoin(originalCond = condition, expectedCond = expectedCond)
262
+ testDelete(originalCond = condition, expectedCond = expectedCond)
263
+ testUpdate(originalCond = condition, expectedCond = expectedCond)
259
264
}
260
265
261
266
test(" replace null in If used as a join condition" ) {
@@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
405
410
val lambda1 = LambdaFunction (
406
411
function = If (cond, Literal (null , BooleanType ), TrueLiteral ),
407
412
arguments = lambdaArgs)
408
- // the optimized lambda body is: if(arg > 0, false, true)
413
+ // the optimized lambda body is: if(arg > 0, false, true) => arg <= 0
409
414
val lambda2 = LambdaFunction (
410
- function = If (cond, FalseLiteral , TrueLiteral ),
415
+ function = LessThanOrEqual (condArg, Literal ( 0 ) ),
411
416
arguments = lambdaArgs)
412
417
testProjection(
413
418
originalExpr = createExpr(argument, lambda1) as ' x ,
0 commit comments