18
18
package org .apache .spark .sql .catalyst .optimizer
19
19
20
20
import org .apache .spark .sql .catalyst .analysis .EliminateAnalysisOperators
21
- import org .apache .spark .sql .catalyst .expressions .{ Or , And , Literal , Expression }
21
+ import org .apache .spark .sql .catalyst .expressions ._
22
22
import org .apache .spark .sql .catalyst .plans .logical ._
23
23
import org .apache .spark .sql .catalyst .plans .PlanTest
24
24
import org .apache .spark .sql .catalyst .rules ._
25
25
import org .apache .spark .sql .catalyst .dsl .plans ._
26
26
import org .apache .spark .sql .catalyst .dsl .expressions ._
27
27
28
- class BooleanSimplificationSuite extends PlanTest {
28
+ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
29
29
30
30
object Optimize extends RuleExecutor [LogicalPlan ] {
31
31
val batches =
@@ -40,14 +40,21 @@ class BooleanSimplificationSuite extends PlanTest {
40
40
41
41
val testRelation = LocalRelation (' a .int, ' b .int, ' c .int, ' d .string)
42
42
43
+ // The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
43
44
def compareConditions (e1 : Expression , e2 : Expression ): Boolean = (e1, e2) match {
44
- case (And (l1, l2), And (r1, r2)) =>
45
- compareConditions(l1, r1) && compareConditions(l2, r2) ||
46
- compareConditions(l1, r2) && compareConditions(l2, r1)
47
-
48
- case (Or (l1, l2), Or (r1, r2)) =>
49
- compareConditions(l1, r1) && compareConditions(l2, r2) ||
50
- compareConditions(l1, r2) && compareConditions(l2, r1)
45
+ case (lhs : And , rhs : And ) =>
46
+ val lhsSet = splitConjunctivePredicates(lhs).toSet
47
+ val rhsSet = splitConjunctivePredicates(rhs).toSet
48
+ lhsSet.foldLeft(rhsSet) { (set, e) =>
49
+ set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
50
+ }.isEmpty
51
+
52
+ case (lhs : Or , rhs : Or ) =>
53
+ val lhsSet = splitDisjunctivePredicates(lhs).toSet
54
+ val rhsSet = splitDisjunctivePredicates(rhs).toSet
55
+ lhsSet.foldLeft(rhsSet) { (set, e) =>
56
+ set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
57
+ }.isEmpty
51
58
52
59
case (l, r) => l == r
53
60
}
0 commit comments