[SPARK-32801][SQL] Make InferFiltersFromConstraints take into account EqualNullSafe #29650

Closed · 11 commits

@@ -60,18 +60,25 @@ trait ConstraintHelper {
    */
   def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
     var inferredConstraints = ExpressionSet()
-    // IsNotNull should be constructed by `constructIsNotNullConstraints`.
-    val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
-    predicates.foreach {
-      case eq @ EqualTo(l: Attribute, r: Attribute) =>
-        val candidateConstraints = predicates - eq
-        inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
-        inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
-      case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) =>
-        inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
-      case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
-        inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
-      case _ => // No inference
+    var prevSize = -1
+    while (inferredConstraints.size > prevSize) {
+      prevSize = inferredConstraints.size
+      val predicates = (constraints ++ inferredConstraints)
+        // IsNotNull should be constructed by `constructIsNotNullConstraints`.
+        .filterNot(_.isInstanceOf[IsNotNull])
+        // Non-deterministic expressions are never equal to each other and would cause OOM.
+        .filter(_.deterministic)
+      predicates.foreach {
+        case eq @ Equality(l: Attribute, r: Attribute) =>
+          val candidateConstraints = predicates - eq
+          inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
+          inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+        case eq @ Equality(l @ Cast(_: Attribute, _, _), r: Attribute) =>
+          inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
+        case eq @ Equality(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
+          inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
+        case _ => // No inference
+      }
     }
     inferredConstraints -- constraints
   }
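
Two things change in this hunk: the match now uses Catalyst's `Equality` extractor, which accepts both `EqualTo` and the null-safe `EqualNullSafe` (the point of this PR), and the substitution now runs to a fixed point instead of a single pass. Below is a minimal, self-contained sketch of the fixed-point part, using toy classes rather than Catalyst expressions: starting from `a = b`, `b = c`, `a > 5`, the first pass only discovers `a = c` and `b > 5`; it takes the second pass to reach `c > 5`, which is exactly the "chain of filters" case covered by the new test further down.

object FixedPointInferenceSketch {
  // Toy expression tree; these are stand-ins, not Spark's Catalyst classes.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Eq(l: Expr, r: Expr) extends Expr   // stands in for EqualTo / EqualNullSafe
  case class Gt(l: Expr, r: Expr) extends Expr
  case class Lit(v: Int) extends Expr

  // Replace every occurrence of `from` with `to`, playing the role of `replaceConstraints`.
  private def substitute(e: Expr, from: Expr, to: Expr): Expr = e match {
    case `from`   => to
    case Eq(l, r) => Eq(substitute(l, from, to), substitute(r, from, to))
    case Gt(l, r) => Gt(substitute(l, from, to), substitute(r, from, to))
    case other    => other
  }

  def infer(constraints: Set[Expr]): Set[Expr] = {
    var inferred = Set.empty[Expr]
    var prevSize = -1
    while (inferred.size > prevSize) { // stop once a pass adds nothing new
      prevSize = inferred.size
      val predicates = constraints ++ inferred
      predicates.foreach {
        case eq @ Eq(l: Attr, r: Attr) =>
          val others = predicates - eq
          inferred ++= others.map(substitute(_, l, r))
          inferred ++= others.map(substitute(_, r, l))
        case _ => // no inference
      }
    }
    inferred -- constraints
  }

  def main(args: Array[String]): Unit = {
    val a = Attr("a"); val b = Attr("b"); val c = Attr("c")
    // Prints a set containing Gt(Attr(c),Lit(5)), which a single pass would miss.
    println(infer(Set(Eq(a, b), Eq(b, c), Gt(a, Lit(5)))))
  }
}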
@@ -351,6 +351,7 @@ case class Join(
           .union(ExpressionSet(splitConjunctivePredicates(condition.get)))
       case LeftSemi if condition.isDefined =>
         left.constraints
+          .union(right.constraints)
           .union(ExpressionSet(splitConjunctivePredicates(condition.get)))
       case j: ExistenceJoin =>
         left.constraints
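
The one-line change above makes a LEFT SEMI join also expose its right child's constraints. A toy, collection-level analogue of why this is sound (plain Scala, not Catalyst; paste into any Scala REPL): every row that survives a left-semi join has at least one matching right row, so a predicate already guaranteed on the right join key carries over to the output through the equality condition.

val left     = Seq(0, 1, 2, 3)
val right    = Seq(1, 2, 3, 4).filter(_ > 1)       // right side is known to satisfy `> 1`
val leftSemi = left.filter(l => right.contains(l)) // LEFT SEMI: keep left rows with a match
assert(leftSemi.forall(_ > 1))                     // the right-side constraint holds on the output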
@@ -29,15 +29,21 @@ import org.apache.spark.sql.types.{IntegerType, LongType}
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("InferAndPushDownFilters", FixedPoint(100),
-        PushPredicateThroughJoin,
-        PushPredicateThroughNonJoin,
-        InferFiltersFromConstraints,
-        CombineFilters,
-        SimplifyBinaryComparison,
-        BooleanSimplification,
-        PruneFilters) :: Nil
+    val operatorOptimizationRuleSet = Seq(
+      PushDownPredicates,
+      BooleanSimplification,
+      SimplifyBinaryComparison,
+      PruneFilters)
+
+    val batches = Batch(
+      "Operator Optimization before Inferring Filters",
+      FixedPoint(100),
+      operatorOptimizationRuleSet: _*) ::
+      Batch("Infer Filters", Once, InferFiltersFromConstraints) ::
+      Batch(
+        "Operator Optimization after Inferring Filters",
+        FixedPoint(100),
+        operatorOptimizationRuleSet: _*) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -316,4 +322,75 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         condition)
     }
   }
+
+  test(
+    "SPARK-32801: Single inner join with EqualNullSafe condition: " +
+      "filter out values on either side on equi-join keys") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery =
+      x.join(y, condition = Some(("x.a".attr <=> "y.a".attr) && ("x.a".attr > 5))).analyze
+    val left = x.where(IsNotNull('a) && "x.a".attr > 5)
+    val right = y.where(IsNotNull('a) && "y.a".attr > 5)
+    val correctAnswer = left.join(right, condition = Some("x.a".attr <=> "y.a".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32801: Infer all constraints from a chain of filters") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x
+      .where("x.a".attr === "x.b".attr)
+      .join(y, condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .analyze
+    val left = x.where(IsNotNull('a) && IsNotNull('b) && "x.a".attr === "x.b".attr)
+    val right = y.where(IsNotNull('a) && IsNotNull('b) && "y.a".attr === "y.b".attr)
+    val correctAnswer = left
+      .join(right, condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32801: Infer from right side of left semi join") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val z = testRelation.subquery('z)
+    val originalQuery = x
+      .join(
+        y.join(
+          z.where("z.a".attr > 1),
+          condition = Some("y.a".attr === "z.a".attr),
+          joinType = LeftSemi),
+        condition = Some("x.a".attr === "y.a".attr))
+      .analyze
+    val correctX = x.where(IsNotNull('a) && "x.a".attr > 1)
+    val correctY = y.where(IsNotNull('a) && "y.a".attr > 1)
+    val correctZ = z.where(IsNotNull('a) && "z.a".attr > 1)
+    val correctAnswer = correctX
+      .join(
+        correctY.join(correctZ, condition = Some("y.a".attr === "z.a".attr), joinType = LeftSemi),
+        condition = Some("x.a".attr === "y.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32801: Non-deterministic filters do not introduce an infinite loop") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x
+      .join(y, condition = Some("x.a".attr === "y.a".attr))
+      .where(rand(0) === "x.a".attr)
+      .analyze
+    val left = x.where(IsNotNull('a))
+    val right = y.where(IsNotNull('a))
+    val correctAnswer = left
+      .join(right, condition = Some("x.a".attr === "y.a".attr))
+      .where(rand(0) === "x.a".attr)
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
 }
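
Taken together, the user-visible effect is that filters can now be inferred across a null-safe equi-join. A hypothetical spark-shell session illustrating the first new test at the DataFrame level (standard Spark SQL API; the data and column names are made up, and it assumes a build that includes this patch):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("spark-32801-demo").getOrCreate()
import spark.implicits._

val x = Seq(1, 6, 7).toDF("a")
val y = Seq(6, 7, 8).toDF("a")

// `<=>` is the null-safe equality operator (EqualNullSafe).
val q = x.join(y, x("a") <=> y("a")).where(x("a") > 5)
q.explain(true)
// With this change the optimized plan carries `isnotnull(a) AND a > 5` on BOTH join inputs;
// previously the rule matched only EqualTo, so nothing was inferred for the y side.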