Skip to content

Commit 3c02357

Browse files
jiangxb1987hvanhovell
authored andcommitted
[SPARK-17733][SQL] InferFiltersFromConstraints rule never terminates for query
## What changes were proposed in this pull request? The function `QueryPlan.inferAdditionalConstraints` and `UnaryNode.getAliasedConstraints` can produce a non-converging set of constraints for recursive functions. For instance, if we have two constraints of the form(where a is an alias): `a = b, a = f(b, c)` Applying both these rules in the next iteration would infer: `f(b, c) = f(f(b, c), c)` This process repeated, the iteration won't converge and the set of constraints will grow larger and larger until OOM. ~~To fix this problem, we collect alias from expressions and skip infer constraints if we are to transform an `Expression` to another which contains it.~~ To fix this problem, we apply additional check in `inferAdditionalConstraints`, when it's possible to generate recursive constraints, we skip generate that. ## How was this patch tested? Add new testcase in `SQLQuerySuite`/`InferFiltersFromConstraintsSuite`. Author: jiangxingbo <jiangxb1987@gmail.com> Closes apache#15319 from jiangxb1987/constraints.
1 parent 402205d commit 3c02357

File tree

4 files changed

+191
-14
lines changed

4 files changed

+191
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,26 +68,104 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
6868
case _ => Seq.empty[Attribute]
6969
}
7070

71+
// Collect aliases from expressions, so we may avoid producing recursive constraints.
72+
private lazy val aliasMap = AttributeMap(
73+
(expressions ++ children.flatMap(_.expressions)).collect {
74+
case a: Alias => (a.toAttribute, a.child)
75+
})
76+
7177
/**
7278
* Infers an additional set of constraints from a given set of equality constraints.
7379
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
74-
* additional constraint of the form `b = 5`
80+
* additional constraint of the form `b = 5`.
81+
*
82+
* [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
83+
* as they are often useless and can lead to a non-converging set of constraints.
7584
*/
7685
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
86+
val constraintClasses = generateEquivalentConstraintClasses(constraints)
87+
7788
var inferredConstraints = Set.empty[Expression]
7889
constraints.foreach {
7990
case eq @ EqualTo(l: Attribute, r: Attribute) =>
80-
inferredConstraints ++= (constraints - eq).map(_ transform {
81-
case a: Attribute if a.semanticEquals(l) => r
91+
val candidateConstraints = constraints - eq
92+
inferredConstraints ++= candidateConstraints.map(_ transform {
93+
case a: Attribute if a.semanticEquals(l) &&
94+
!isRecursiveDeduction(r, constraintClasses) => r
8295
})
83-
inferredConstraints ++= (constraints - eq).map(_ transform {
84-
case a: Attribute if a.semanticEquals(r) => l
96+
inferredConstraints ++= candidateConstraints.map(_ transform {
97+
case a: Attribute if a.semanticEquals(r) &&
98+
!isRecursiveDeduction(l, constraintClasses) => l
8599
})
86100
case _ => // No inference
87101
}
88102
inferredConstraints -- constraints
89103
}
90104

105+
/*
106+
* Generate a sequence of expression sets from constraints, where each set stores an equivalence
107+
* class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
108+
* expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
109+
* to an selected attribute.
110+
*/
111+
private def generateEquivalentConstraintClasses(
112+
constraints: Set[Expression]): Seq[Set[Expression]] = {
113+
var constraintClasses = Seq.empty[Set[Expression]]
114+
constraints.foreach {
115+
case eq @ EqualTo(l: Attribute, r: Attribute) =>
116+
// Transform [[Alias]] to its child.
117+
val left = aliasMap.getOrElse(l, l)
118+
val right = aliasMap.getOrElse(r, r)
119+
// Get the expression set for an equivalence constraint class.
120+
val leftConstraintClass = getConstraintClass(left, constraintClasses)
121+
val rightConstraintClass = getConstraintClass(right, constraintClasses)
122+
if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
123+
// Combine the two sets.
124+
constraintClasses = constraintClasses
125+
.diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
126+
(leftConstraintClass ++ rightConstraintClass)
127+
} else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
128+
// Update equivalence class of `left` expression.
129+
constraintClasses = constraintClasses
130+
.diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
131+
} else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
132+
// Update equivalence class of `right` expression.
133+
constraintClasses = constraintClasses
134+
.diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
135+
} else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
136+
// Create new equivalence constraint class since neither expression presents
137+
// in any classes.
138+
constraintClasses = constraintClasses :+ Set(left, right)
139+
}
140+
case _ => // Skip
141+
}
142+
143+
constraintClasses
144+
}
145+
146+
/*
147+
* Get all expressions equivalent to the selected expression.
148+
*/
149+
private def getConstraintClass(
150+
expr: Expression,
151+
constraintClasses: Seq[Set[Expression]]): Set[Expression] =
152+
constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
153+
154+
/*
155+
* Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
156+
* has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
157+
* Here we first get all expressions equal to `attr` and then check whether at least one of them
158+
* is a child of the referenced expression.
159+
*/
160+
private def isRecursiveDeduction(
161+
attr: Attribute,
162+
constraintClasses: Seq[Set[Expression]]): Boolean = {
163+
val expr = aliasMap.getOrElse(attr, attr)
164+
getConstraintClass(expr, constraintClasses).exists { e =>
165+
expr.children.exists(_.semanticEquals(e))
166+
}
167+
}
168+
91169
/**
92170
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
93171
* example, if this set contains the expression `a = 2` then that expression is guaranteed to

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._
2727
class InferFiltersFromConstraintsSuite extends PlanTest {
2828

2929
object Optimize extends RuleExecutor[LogicalPlan] {
30-
val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
31-
Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
32-
Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
30+
val batches =
31+
Batch("InferAndPushDownFilters", FixedPoint(100),
32+
PushPredicateThroughJoin,
33+
PushDownPredicate,
34+
InferFiltersFromConstraints,
35+
CombineFilters) :: Nil
3336
}
3437

3538
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -120,4 +123,82 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
120123
val optimized = Optimize.execute(originalQuery)
121124
comparePlans(optimized, correctAnswer)
122125
}
126+
127+
test("inner join with alias: alias contains multiple attributes") {
128+
val t1 = testRelation.subquery('t1)
129+
val t2 = testRelation.subquery('t2)
130+
131+
val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
132+
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
133+
.analyze
134+
val correctAnswer = t1
135+
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)))
136+
.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
137+
.join(t2.where(IsNotNull('a)), Inner,
138+
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
139+
.analyze
140+
val optimized = Optimize.execute(originalQuery)
141+
comparePlans(optimized, correctAnswer)
142+
}
143+
144+
test("inner join with alias: alias contains single attributes") {
145+
val t1 = testRelation.subquery('t1)
146+
val t2 = testRelation.subquery('t2)
147+
148+
val originalQuery = t1.select('a, 'b.as('d)).as("t")
149+
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
150+
.analyze
151+
val correctAnswer = t1
152+
.where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
153+
.select('a, 'b.as('d)).as("t")
154+
.join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
155+
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
156+
.analyze
157+
val optimized = Optimize.execute(originalQuery)
158+
comparePlans(optimized, correctAnswer)
159+
}
160+
161+
test("inner join with alias: don't generate constraints for recursive functions") {
162+
val t1 = testRelation.subquery('t1)
163+
val t2 = testRelation.subquery('t2)
164+
165+
val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
166+
.join(t2, Inner,
167+
Some("t.a".attr === "t2.a".attr
168+
&& "t.d".attr === "t2.a".attr
169+
&& "t.int_col".attr === "t2.a".attr))
170+
.analyze
171+
val correctAnswer = t1
172+
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
173+
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
174+
&& Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
175+
&& 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
176+
&& Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
177+
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
178+
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
179+
&& Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
180+
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
181+
.join(t2
182+
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
183+
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
184+
&& Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
185+
Some("t.a".attr === "t2.a".attr
186+
&& "t.d".attr === "t2.a".attr
187+
&& "t.int_col".attr === "t2.a".attr
188+
&& Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr))
189+
.analyze
190+
val optimized = Optimize.execute(originalQuery)
191+
comparePlans(optimized, correctAnswer)
192+
}
193+
194+
test("generate correct filters for alias that don't produce recursive constraints") {
195+
val t1 = testRelation.subquery('t1)
196+
197+
val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze
198+
val correctAnswer =
199+
t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b))
200+
.select('a.as('x), 'b.as('y)).analyze
201+
val optimized = Optimize.execute(originalQuery)
202+
comparePlans(optimized, correctAnswer)
203+
}
123204
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
23-
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
23+
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.util._
2525

2626
/**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
5656
* ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
5757
* etc., will all now be equivalent.
5858
* - Sample the seed will replaced by 0L.
59+
* - Join conditions will be resorted by hashCode.
5960
*/
6061
private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
6162
plan transform {
6263
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
63-
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
64+
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
65+
.reduce(And), child)
6466
case sample: Sample =>
6567
sample.copy(seed = 0L)(true)
68+
case join @ Join(left, right, joinType, condition) if condition.isDefined =>
69+
val newCondition =
70+
splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
71+
.reduce(And)
72+
Join(left, right, joinType, Some(newCondition))
6673
}
6774
}
6875

76+
/**
77+
* Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be
78+
* equivalent:
79+
* 1. (a = b), (b = a);
80+
* 2. (a <=> b), (b <=> a).
81+
*/
82+
private def rewriteEqual(condition: Expression): Expression = condition match {
83+
case eq @ EqualTo(l: Expression, r: Expression) =>
84+
Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
85+
case eq @ EqualNullSafe(l: Expression, r: Expression) =>
86+
Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
87+
case _ => condition // Don't reorder.
88+
}
89+
6990
/** Fails the test if the two plans do not match */
7091
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
7192
val normalized1 = normalizePlan(normalizeExprIds(plan1))

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@ package org.apache.spark.sql
1919

2020
import java.io.File
2121
import java.math.MathContext
22-
import java.sql.{Date, Timestamp}
22+
import java.sql.Timestamp
2323

2424
import org.apache.spark.{AccumulatorSuite, SparkException}
25-
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
26-
import org.apache.spark.sql.catalyst.expressions.SortOrder
27-
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
2825
import org.apache.spark.sql.catalyst.util.StringUtils
2926
import org.apache.spark.sql.execution.aggregate
3027
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}

0 commit comments

Comments
 (0)