Skip to content

Commit cd8860b

Browse files
committed
Improves compareConditions to handle more subtle cases
1 parent 1bf3258 commit cd8860b

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
21-
import org.apache.spark.sql.catalyst.expressions.{Or, And, Literal, Expression}
21+
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.plans.logical._
2323
import org.apache.spark.sql.catalyst.plans.PlanTest
2424
import org.apache.spark.sql.catalyst.rules._
2525
import org.apache.spark.sql.catalyst.dsl.plans._
2626
import org.apache.spark.sql.catalyst.dsl.expressions._
2727

28-
class BooleanSimplificationSuite extends PlanTest {
28+
class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
2929

3030
object Optimize extends RuleExecutor[LogicalPlan] {
3131
val batches =
@@ -40,14 +40,21 @@ class BooleanSimplificationSuite extends PlanTest {
4040

4141
val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)
4242

43+
// The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
4344
def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match {
44-
case (And(l1, l2), And(r1, r2)) =>
45-
compareConditions(l1, r1) && compareConditions(l2, r2) ||
46-
compareConditions(l1, r2) && compareConditions(l2, r1)
47-
48-
case (Or(l1, l2), Or(r1, r2)) =>
49-
compareConditions(l1, r1) && compareConditions(l2, r2) ||
50-
compareConditions(l1, r2) && compareConditions(l2, r1)
45+
case (lhs: And, rhs: And) =>
46+
val lhsSet = splitConjunctivePredicates(lhs).toSet
47+
val rhsSet = splitConjunctivePredicates(rhs).toSet
48+
lhsSet.foldLeft(rhsSet) { (set, e) =>
49+
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
50+
}.isEmpty
51+
52+
case (lhs: Or, rhs: Or) =>
53+
val lhsSet = splitDisjunctivePredicates(lhs).toSet
54+
val rhsSet = splitDisjunctivePredicates(rhs).toSet
55+
lhsSet.foldLeft(rhsSet) { (set, e) =>
56+
set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
57+
}.isEmpty
5158

5259
case (l, r) => l == r
5360
}

0 commit comments

Comments
 (0)