[SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False #23372

Closed
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala

@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Note:
  * Before flipping the filter condition of the right node, we should:
  * 1. Combine all its [[Filter]]s.
- * 2. Apply InferFiltersFromConstraints rule (to take into account NULL values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to take into account NULL values in the condition).
  */
 object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
 
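Why the rewrite needs the new step 3: under SQL three-valued logic a predicate over NULL input evaluates to NULL rather than false, and a Filter drops NULL rows exactly like false ones, so flipping the right side's condition with a bare Not() silently lost rows that EXCEPT must return. A minimal reproduction sketch (assumes a local SparkSession; the data and column names are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("spark-26366-repro").getOrCreate()
import spark.implicits._

// "c1 >= 0" is NULL (not true) for the second row, so EXCEPT must return it;
// before this fix, Not("c1 >= 0") was also NULL and the row was wrongly dropped.
val df = Seq((Some(0), Some(3)), (None, Some(3))).toDF("c1", "c2")
val other = df.filter("c1 >= 0")
df.except(other).show() // expected output: the (null, 3) row
```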
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {

     plan.transform {
       case e @ Except(left, right) if isEligible(left, right) =>
-        val newCondition = transformCondition(left, skipProject(right))
-        newCondition.map { c =>
-          Distinct(Filter(Not(c), left))
-        }.getOrElse {
+        val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+        if (filterCondition.deterministic) {
+          transformCondition(left, filterCondition).map { c =>
+            Distinct(Filter(Not(c), left))
+          }.getOrElse {
+            e
+          }
+        } else {
           e
         }
     }
   }
 
-  private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
-    val filterCondition =
-      InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
-    val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
-    if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
-      Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+  private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+    val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+    if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+      val rewrittenCondition = condition.transform {
+        case a: AttributeReference => attributeNameMap(a.name)
+      }
+      // We need to consider the condition as False when it is NULL, otherwise we do not return
+      // those rows containing NULL which are instead filtered out in the Except right plan
+      Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
     } else {
       None
     }
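The rewritten plan for `left EXCEPT right` is now `Distinct(Filter(Not(Coalesce(condition, false)), left))`, and, since InferFiltersFromConstraints is no longer applied inside the rule, the expected plans in the suite below also lose the inferred isNotNull conjuncts. A sketch of the old and new predicate shapes in the Catalyst DSL (the condition is illustrative, not from the patch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Coalesce, Literal, Not}

// Hypothetical right-side filter condition over int columns a and b.
val cond = ('a.int >= 2) && ('b.int < 1)

// Old flip: if cond is NULL, Not(cond) is NULL and Filter drops the row (wrong).
val oldFlip = Not(cond)

// New flip: a NULL cond is first collapsed to false, so Not(...) is true (row kept).
val newFlip = Not(Coalesce(Seq(cond, Literal.FalseLiteral)))
```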
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala

@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
 
 class ReplaceOperatorSuite extends PlanTest {
 
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)), table1)).analyze
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+          table1)).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA >= 2 && attributeB < 1)),
+        Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB), table1))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
           Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
 
     comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
 
     val correctAnswer =
       Aggregate(table1.output, table1.output,
-        Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
-          (attributeA === 1 && attributeB === 2)),
+        Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
           Project(Seq(attributeA, attributeB),
             Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
 
@@ -229,4 +226,27 @@ class ReplaceOperatorSuite extends PlanTest {
 
     comparePlans(optimized, query)
   }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should handle NULL properly") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+    val except = Except(basePlan, otherPlan)
+    val result = OptimizeIn(Optimize.execute(except.analyze))
+    val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+      Filter(!Coalesce(Seq('a.in(1, 2) || 'b.in(), Literal.FalseLiteral)),
+        basePlan)).analyze
+    comparePlans(result, correctAnswer)
+  }
+
+  test("SPARK-26366: ReplaceExceptWithFilter should not transform non-deterministic") {
+    val basePlan = LocalRelation(Seq('a.int, 'b.int))
+    val otherPlan = basePlan.where('a > rand(1L))
+    val except = Except(basePlan, otherPlan)
+    val result = Optimize.execute(except.analyze)
+    val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+      a1 <=> a2 }.reduce(_ && _)
+    val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+      Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+    comparePlans(result, correctAnswer)
+  }
 }
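The second new test pins down the other half of the fix: a non-deterministic right-side condition such as `'a > rand(1L)` cannot be flipped, because re-evaluating `rand` against the left side would produce different values, so the rule now bails out and the Except is planned as a null-safe left-anti join instead. A DataFrame-level sketch of the same scenario (hypothetical data; assumes the SparkSession and implicits from the earlier sketch):

```scala
import org.apache.spark.sql.functions.rand

val base = Seq((1, 2), (3, 4)).toDF("a", "b")
val nonDet = base.filter($"a" > rand(1L))

// The plan should contain a LeftAnti join on a <=> a && b <=> b,
// not a flipped Filter over `base`.
base.except(nonDet).explain(true)
```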
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
       Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
   }
+
+  test("SPARK-26366: return nulls which are not filtered in except") {
+    val inputDF = sqlContext.createDataFrame(
+      sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+      StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", StringType, nullable = true))))
+
+    val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+    checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
38 changes: 38 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2831,6 +2831,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000")))
     }
   }
+
+  test("SPARK-26366: verify ReplaceExceptWithFilter") {
+    Seq(true, false).foreach { enabled =>
+      withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+        val df = spark.createDataFrame(
+          sparkContext.parallelize(Seq(Row(0, 3, 5),
+            Row(0, 3, null),
+            Row(null, 3, 5),
+            Row(0, null, 5),
+            Row(0, null, null),
+            Row(null, null, 5),
+            Row(null, 3, null),
+            Row(null, null, null))),
+          StructType(Seq(StructField("c1", IntegerType),
+            StructField("c2", IntegerType),
+            StructField("c3", IntegerType))))
+        val where = "c2 >= 3 OR c1 >= 0"
+        val whereNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+            |OR (c1 IS NOT NULL AND c1 >= 0)
+          """.stripMargin
+
+        val df_a = df.filter(where)
+        val df_b = df.filter(whereNullSafe)
+        checkAnswer(df.except(df_a), df.except(df_b))
+
+        val whereWithIn = "c2 >= 3 OR c1 in (2)"
+        val whereWithInNullSafe =
+          """
+            |(c2 IS NOT NULL AND c2 >= 3)
+          """.stripMargin
+        val dfIn_a = df.filter(whereWithIn)
+        val dfIn_b = df.filter(whereWithInNullSafe)
+        checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
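The end-to-end test above exercises both code paths by toggling `SQLConf.REPLACE_EXCEPT_WITH_FILTER`; either way, EXCEPT must behave as if the right side's predicate had been written in the null-safe form. A sketch of the same check in a spark-shell session (assumes the `df` from the test, and that the constant's key is the internal option `spark.sql.optimizer.replaceExceptWithFilter`):

```scala
// The EXCEPT answer must not depend on whether the rewrite fires.
Seq("true", "false").foreach { flag =>
  spark.conf.set("spark.sql.optimizer.replaceExceptWithFilter", flag)
  df.except(df.filter("c2 >= 3 OR c1 >= 0")).show()
}
```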