Skip to content

Commit

Permalink
[SPARK-49164][SQL] Fix not NullSafeEqual in predicate of SQL query in…
Browse files Browse the repository at this point in the history
… JDBC Relation

### What changes were proposed in this pull request?
Changed the evaluation of <=> (NullEqualSafe) in V2ExpressionSQLBuilder.scala. If there was a predicate like
where not str <=> 'abc' and if there are null values in the table, null tables are not returned even though they should. The issue was in how <=> is translated to SQL string therefore producing wrong results.

### Why are the changes needed?
There is a bug currently.

### Does this PR introduce _any_ user-facing change?
Yes, since users can get different results now.

### How was this patch tested?
Existing tests were modified and one new test is added.

### Was this patch authored or co-authored using generative AI tooling?

Closes apache#47669 from PetarVasiljevic-DB/fix_not_null_safe_equal.

Authored-by: Petar Vasiljevic <petar.vasiljevic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
PetarVasiljevic-DB authored and MaxGekk committed Aug 12, 2024
1 parent 2465cb0 commit cb7919b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ protected String inputToSQL(Expression input) {

protected String visitBinaryComparison(String name, String l, String r) {
if (name.equals("<=>")) {
return "(" + l + " = " + r + ") OR (" + l + " IS NULL AND " + r + " IS NULL)";
return "((" + l + " IS NOT NULL AND " + r + " IS NOT NULL AND " + l + " = " + r + ") " +
"OR (" + l + " IS NULL AND " + r + " IS NULL))";
}
return l + " " + name + " " + r;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class V2PredicateSuite extends SparkFunSuite {
val predicate2 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
assert(predicate1.equals(predicate2))
assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
assert(predicate1.describe.equals("(a = 1) OR (a IS NULL AND 1 IS NULL)"))
assert(predicate1.describe.equals(
"((a IS NOT NULL AND 1 IS NOT NULL AND a = 1) OR (a IS NULL AND 1 IS NULL))"))

val v1Filter = EqualNullSafe("a", 1)
assert(v1Filter.toV2 == predicate1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -839,8 +839,8 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
assert(doCompileFilter(IsNull(col1)) === """"col1" IS NULL""")
assert(doCompileFilter(IsNotNull(col1)) === """"col1" IS NOT NULL""")
assert(doCompileFilter(And(EqualNullSafe(col0, "abc"), EqualTo(col1, "def")))
=== """(("col0" = 'abc') OR ("col0" IS NULL AND 'abc' IS NULL))"""
+ """ AND ("col1" = 'def')""")
=== """((("col0" IS NOT NULL AND 'abc' IS NOT NULL AND "col0" = 'abc') """ +
"""OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""")
}
assert(doCompileFilter(EqualTo("col0.nested", 3)).isEmpty)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
"VALUES ('david', 10000, 1300, 0.13)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee_bonus\" " +
"VALUES ('jen', 12000, 2400, 0.2)").executeUpdate()

conn.prepareStatement(
"CREATE TABLE \"test\".\"strings_with_nulls\" (str TEXT(32))").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"strings_with_nulls\" VALUES " +
"('abc')").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"strings_with_nulls\" VALUES " +
"('a a a')").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"strings_with_nulls\" VALUES " +
"(null)").executeUpdate()
}
h2Dialect.registerFunction("my_avg", IntegralAverage)
h2Dialect.registerFunction("my_strlen", StrLen(CharLength))
Expand Down Expand Up @@ -1769,7 +1778,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
Row("test", "item", false), Row("test", "dept", false),
Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false),
Row("test", "datetime", false), Row("test", "binary1", false),
Row("test", "employee_bonus", false)))
Row("test", "employee_bonus", false),
Row("test", "strings_with_nulls", false)))
}

test("SQL API: create table as select") {
Expand Down Expand Up @@ -3078,4 +3088,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val explained = getNormalizedExplain(df, FormattedMode)
assert(explained.contains("External engine query:"))
}

test("Test not nullSafeEqual") {
val df = sql("SELECT str FROM h2.test.strings_with_nulls WHERE NOT str <=> 'abc'")
val rows = df.collect()

assert(rows.length == 2)
assert(rows.contains(Row(null)))
assert(rows.contains(Row("a a a")))
}
}

0 comments on commit cb7919b

Please sign in to comment.