Commit 197bc29
[SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False
In `ReplaceExceptWithFilter` we do not properly handle the case in which the condition evaluates to NULL. Because negating NULL still yields NULL, the assumption that negating the condition returns all the rows which did not satisfy it does not hold: rows for which the condition evaluates to NULL may be missing from the result. This happens when the constraints inferred by `InferFiltersFromConstraints` are not enough, as with `OR` conditions. The rule also had a problem with non-deterministic conditions: in that scenario it would change the probability of the output.

The PR fixes these problems by:
- treating the condition as False when it evaluates to NULL (so that all the rows which did not satisfy it are returned);
- skipping the transformation entirely when the condition is non-deterministic.

Tested with added UTs.

Closes #23315 from mgaido91/SPARK-26366.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
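To make the failure mode concrete, here is a minimal spark-shell sketch of the case the message describes (the data and column names are illustrative, not from the commit): under SQL's three-valued logic, NOT(condition) is NULL rather than true for rows where the condition is NULL, so a naive Filter(Not(condition)) rewrite of EXCEPT silently drops those rows.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // Illustrative data: the second row makes the OR condition evaluate to NULL.
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row(0, 3), Row(null, null))),
      StructType(Seq(
        StructField("c1", IntegerType, nullable = true),
        StructField("c2", IntegerType, nullable = true))))

    val filtered = df.filter("c2 >= 3 OR c1 >= 0")

    // Row(null, null) does not satisfy the filter, so EXCEPT must return it.
    // Before this fix, negating the condition yielded NULL for that row and the
    // row was lost; with Coalesce(condition, false) it is kept.
    df.except(filtered).show()  // expected: a single row [null,null]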
1 parent 832812e commit 197bc29

File tree: 4 files changed, +101 −24 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
(19 additions, 13 deletions)

@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Note:
  * Before flipping the filter condition of the right node, we should:
  * 1. Combine all it's [[Filter]].
- * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition).
  */
 object ReplaceExceptWithFilter extends Rule[LogicalPlan] {

@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {

     plan.transform {
       case e @ Except(left, right) if isEligible(left, right) =>
-        val newCondition = transformCondition(left, skipProject(right))
-        newCondition.map { c =>
-          Distinct(Filter(Not(c), left))
-        }.getOrElse {
+        val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+        if (filterCondition.deterministic) {
+          transformCondition(left, filterCondition).map { c =>
+            Distinct(Filter(Not(c), left))
+          }.getOrElse {
+            e
+          }
+        } else {
           e
         }
     }
   }

-  private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
-    val filterCondition =
-      InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
-    val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
-    if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
-      Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+  private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+    val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+    if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+      val rewrittenCondition = condition.transform {
+        case a: AttributeReference => attributeNameMap(a.name)
+      }
+      // We need to consider as False when the condition is NULL, otherwise we do not return those
+      // rows containing NULL which are instead filtered in the Except right plan
+      Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
     } else {
       None
     }
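In plan terms, the rule now rewrites an eligible Except into a Distinct over a negated, NULL-safe Filter. A sketch of that shape using the Catalyst test DSL (the relation and condition below are illustrative, not taken from the commit):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{Coalesce, Literal, Not}
    import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, LocalRelation}

    val left = LocalRelation('a.int, 'b.int)
    // Combined Filter condition taken from the right side, with attribute
    // references already rewritten to point at `left`:
    val condition = 'a >= 2 && 'b < 1

    // The NULL-safe rewrite: Coalesce turns a NULL condition into false, so
    // Not(...) keeps exactly the rows the right side filtered out.
    val rewritten = Distinct(Filter(Not(Coalesce(Seq(condition, Literal.FalseLiteral))), left))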

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
(33 additions, 11 deletions)

@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType

 class ReplaceOperatorSuite extends PlanTest {

@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {

     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze

     comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@

     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)), table1)).analyze
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+          table1)).analyze

     comparePlans(optimized, correctAnswer)
   }
@@ -104,8 +104,7 @@

     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB), table1))).analyze

     comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@

     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze

     comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@

     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB),
             Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze

@@ -229,4 +226,29 @@

     comparePlans(optimized, query)
   }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+    val except = Except(basePlan, otherPlan, false)
+    val result = OptimizeIn(Optimize.execute(except.analyze))
+    val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+      Filter(!Coalesce(Seq(
+        'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
+        Literal.FalseLiteral)),
+        basePlan)).analyze
+    comparePlans(result, correctAnswer)
+  }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a > rand(1L))
+    val except = Except(basePlan, otherPlan, false)
+    val result = Optimize.execute(except.analyze)
+    val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+      a1 <=> a2 }.reduce(_ && _)
+    val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+    comparePlans(result, correctAnswer)
+  }
 }
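The second new test pins down the non-deterministic guard. A spark-shell sketch of why the rule must not fire in that case (the data and predicate are illustrative, not from the commit): re-evaluating a random predicate on the left side would re-roll it per row and change each row's survival probability, so the optimizer keeps the default left-anti-join plan instead.

    import org.apache.spark.sql.functions.{col, rand}

    val base = spark.range(0, 100).toDF("a")
    val sampled = base.filter(col("a") > rand(1L) * 100)

    // With this fix the EXCEPT below is still planned as an Aggregate over a
    // LeftAnti join; rewriting it to Filter(Not(condition)) would re-evaluate
    // the random predicate and change the result's distribution.
    base.except(sampled).explain(true)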

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
(11 additions, 0 deletions)

@@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
       Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
   }
+
+  test("SPARK-26366: return nulls which are not filtered in except") {
+    val inputDF = sqlContext.createDataFrame(
+      sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true))))
+
+    val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+    checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+  }
 }

 case class TestDataUnion(x: Int, y: Int, z: Int)

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
(38 additions, 0 deletions)

@@ -2831,6 +2831,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000")))
     }
   }
+
+  test("SPARK-26366: verify ReplaceExceptWithFilter") {
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+        val df = spark.createDataFrame(
+          sparkContext.parallelize(Seq(Row(0, 3, 5),
+            Row(0, 3, null),
+            Row(null, 3, 5),
+            Row(0, null, 5),
+            Row(0, null, null),
+            Row(null, null, 5),
+            Row(null, 3, null),
+            Row(null, null, null))),
+          StructType(Seq(StructField("c1", IntegerType),
+            StructField("c2", IntegerType),
+            StructField("c3", IntegerType))))
+        val where = "c2 >= 3 OR c1 >= 0"
+        val whereNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+            |OR (c1 IS NOT NULL AND c1 >= 0)
+          """.stripMargin
+
+        val df_a = df.filter(where)
+        val df_b = df.filter(whereNullSafe)
+        checkAnswer(df.except(df_a), df.except(df_b))
+
+        val whereWithIn = "c2 >= 3 OR c1 in (2)"
+        val whereWithInNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+          """.stripMargin
+        val dfIn_a = df.filter(whereWithIn)
+        val dfIn_b = df.filter(whereWithInNullSafe)
+        checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+      }
+    }
+  }
 }

 case class Foo(bar: Option[String])
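This SQL test exercises the same identity the fix relies on: NOT(COALESCE(cond, false)) is true exactly when cond is false or NULL, matching the hand-written null-safe predicates above. A quick spark-shell check of the identity (illustrative, not part of the commit):

    // Without COALESCE, negating a NULL condition stays NULL, so the row is lost:
    spark.sql("SELECT NOT CAST(NULL AS BOOLEAN)").show()                    // NULL
    // With COALESCE, the NULL condition is treated as false and NOT(...) is true:
    spark.sql("SELECT NOT COALESCE(CAST(NULL AS BOOLEAN), false)").show()   // true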
