Skip to content

Commit bd28e8e

Browse files
wangshuo128cloud-fan
authored andcommitted
[SPARK-29213][SQL] Generate extra IsNotNull predicate in FilterExec
### What changes were proposed in this pull request? Currently the behavior of getting output and generating null checks in `FilterExec` is different. Thus some nullable attribute could be treated as not nullable by mistake. In `FilterExec.ouput`, an attribute is marked as nullable or not by finding its `exprId` in notNullAttributes: ``` a.nullable && notNullAttributes.contains(a.exprId) ``` But in `FilterExec.doConsume`, a `nullCheck` is generated or not for a predicate is decided by whether there is semantic equal not null predicate: ``` val nullChecks = c.references.map { r => val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} if (idx != -1 && !generatedIsNotNullChecks(idx)) { generatedIsNotNullChecks(idx) = true // Use the child's output. The nullability is what the child produced. genPredicate(notNullPreds(idx), input, child.output) } else { "" } }.mkString("\n").trim ``` NPE will happen when run the SQL below: ``` sql("create table table1(x string)") sql("create table table2(x bigint)") sql("create table table3(x string)") sql("insert into table2 select null as x") sql( """ |select t1.x |from ( | select x from table1) t1 |left join ( | select x from ( | select x from table2 | union all | select substr(x,5) x from table3 | ) a | where length(x)>0 |) t3 |on t1.x=t3.x """.stripMargin).collect() ``` NPE Exception: ``` java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458) at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135) at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52) at org.apache.spark.scheduler.Task.run(Task.scala:127) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` the generated code: ``` == Subtree 4 / 5 == *(2) Project [cast(x#7L as string) AS x#9] +- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string))) +- Scan hive default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage2(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=2 /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private scala.collection.Iterator inputadapter_input_0; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 011 */ /* 012 */ public GeneratedIteratorForCodegenStage2(Object[] references) { /* 013 */ this.references = references; /* 014 */ } /* 015 */ /* 016 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 017 */ partitionIndex = index; /* 018 */ this.inputs = inputs; /* 019 */ inputadapter_input_0 = inputs[0]; /* 020 */ filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 021 */ filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32); /* 022 */ /* 023 */ } /* 024 */ /* 025 */ protected void processNext() throws java.io.IOException { /* 026 */ while ( inputadapter_input_0.hasNext()) { /* 027 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next(); /* 028 */ /* 029 */ do { /* 030 */ boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0); /* 031 */ long inputadapter_value_0 = inputadapter_isNull_0 ? /* 032 */ -1L : (inputadapter_row_0.getLong(0)); /* 033 */ /* 034 */ boolean filter_isNull_2 = inputadapter_isNull_0; /* 035 */ UTF8String filter_value_2 = null; /* 036 */ if (!inputadapter_isNull_0) { /* 037 */ filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0)); /* 038 */ } /* 039 */ int filter_value_1 = -1; /* 040 */ filter_value_1 = (filter_value_2).numChars(); /* 041 */ /* 042 */ boolean filter_value_0 = false; /* 043 */ filter_value_0 = filter_value_1 > 0; /* 044 */ if (!filter_value_0) continue; /* 045 */ /* 046 */ boolean filter_isNull_6 = inputadapter_isNull_0; /* 047 */ UTF8String filter_value_6 = null; /* 048 */ if (!inputadapter_isNull_0) { /* 049 */ filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0)); /* 050 */ } /* 051 */ if (!(!filter_isNull_6)) continue; /* 052 */ /* 053 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 054 */ /* 055 */ boolean project_isNull_0 = false; /* 056 */ UTF8String project_value_0 = null; /* 057 */ if (!false) { /* 058 */ project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0)); /* 059 */ } /* 060 */ filter_mutableStateArray_0[1].reset(); /* 061 */ /* 062 */ filter_mutableStateArray_0[1].zeroOutNullBytes(); /* 063 */ /* 064 */ if (project_isNull_0) { /* 065 */ filter_mutableStateArray_0[1].setNullAt(0); /* 066 */ } else { /* 067 */ filter_mutableStateArray_0[1].write(0, project_value_0); /* 068 */ } /* 069 */ append((filter_mutableStateArray_0[1].getRow())); /* 070 */ /* 071 */ } while(false); /* 072 */ if (shouldStop()) return; /* 073 */ } /* 074 */ } /* 075 */ /* 076 */ } ``` This PR proposes to use semantic comparison both in `FilterExec.output` and `FilterExec.doConsume` for nullable attribute. With this PR, the generated code snippet is below: ``` == Subtree 2 / 5 == *(3) Project [substring(x#8, 5, 2147483647) AS x#5] +- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647))) +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage3(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=3 /* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private scala.collection.Iterator inputadapter_input_0; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 011 */ /* 012 */ public GeneratedIteratorForCodegenStage3(Object[] references) { /* 013 */ this.references = references; /* 014 */ } /* 015 */ /* 016 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 017 */ partitionIndex = index; /* 018 */ this.inputs = inputs; /* 019 */ inputadapter_input_0 = inputs[0]; /* 020 */ filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32); /* 021 */ filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32); /* 022 */ /* 023 */ } /* 024 */ /* 025 */ protected void processNext() throws java.io.IOException { /* 026 */ while ( inputadapter_input_0.hasNext()) { /* 027 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next(); /* 028 */ /* 029 */ do { /* 030 */ boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0); /* 031 */ UTF8String inputadapter_value_0 = inputadapter_isNull_0 ? /* 032 */ null : (inputadapter_row_0.getUTF8String(0)); /* 033 */ /* 034 */ boolean filter_isNull_0 = true; /* 035 */ boolean filter_value_0 = false; /* 036 */ boolean filter_isNull_2 = true; /* 037 */ UTF8String filter_value_2 = null; /* 038 */ /* 039 */ if (!inputadapter_isNull_0) { /* 040 */ filter_isNull_2 = false; // resultCode could change nullability. /* 041 */ filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647); /* 042 */ /* 043 */ } /* 044 */ boolean filter_isNull_1 = filter_isNull_2; /* 045 */ int filter_value_1 = -1; /* 046 */ /* 047 */ if (!filter_isNull_2) { /* 048 */ filter_value_1 = (filter_value_2).numChars(); /* 049 */ } /* 050 */ if (!filter_isNull_1) { /* 051 */ filter_isNull_0 = false; // resultCode could change nullability. /* 052 */ filter_value_0 = filter_value_1 > 0; /* 053 */ /* 054 */ } /* 055 */ if (filter_isNull_0 || !filter_value_0) continue; /* 056 */ boolean filter_isNull_8 = true; /* 057 */ UTF8String filter_value_8 = null; /* 058 */ /* 059 */ if (!inputadapter_isNull_0) { /* 060 */ filter_isNull_8 = false; // resultCode could change nullability. /* 061 */ filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647); /* 062 */ /* 063 */ } /* 064 */ if (!(!filter_isNull_8)) continue; /* 065 */ /* 066 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 067 */ /* 068 */ boolean project_isNull_0 = true; /* 069 */ UTF8String project_value_0 = null; /* 070 */ /* 071 */ if (!inputadapter_isNull_0) { /* 072 */ project_isNull_0 = false; // resultCode could change nullability. /* 073 */ project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647); /* 074 */ /* 075 */ } /* 076 */ filter_mutableStateArray_0[1].reset(); /* 077 */ /* 078 */ filter_mutableStateArray_0[1].zeroOutNullBytes(); /* 079 */ /* 080 */ if (project_isNull_0) { /* 081 */ filter_mutableStateArray_0[1].setNullAt(0); /* 082 */ } else { /* 083 */ filter_mutableStateArray_0[1].write(0, project_value_0); /* 084 */ } /* 085 */ append((filter_mutableStateArray_0[1].getRow())); /* 086 */ /* 087 */ } while(false); /* 088 */ if (shouldStop()) return; /* 089 */ } /* 090 */ } /* 091 */ /* 092 */ } ``` ### Why are the changes needed? Fix NPE bug in FilterExec. ### Does this PR introduce any user-facing change? no ### How was this patch tested? new UT Closes #25902 from wangshuo128/filter-codegen-npe. Authored-by: Wang Shuo <wangshuo128@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent aed7ff3 commit bd28e8e

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
1919

2020
import java.util.concurrent.TimeUnit._
2121

22+
import scala.collection.mutable
2223
import scala.concurrent.{ExecutionContext, Future}
2324
import scala.concurrent.duration.Duration
2425

@@ -171,13 +172,17 @@ case class FilterExec(condition: Expression, child: SparkPlan)
171172
// This is very perf sensitive.
172173
// TODO: revisit this. We can consider reordering predicates as well.
173174
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
175+
val extraIsNotNullAttrs = mutable.Set[Attribute]()
174176
val generated = otherPreds.map { c =>
175177
val nullChecks = c.references.map { r =>
176178
val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
177179
if (idx != -1 && !generatedIsNotNullChecks(idx)) {
178180
generatedIsNotNullChecks(idx) = true
179181
// Use the child's output. The nullability is what the child produced.
180182
genPredicate(notNullPreds(idx), input, child.output)
183+
} else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
184+
extraIsNotNullAttrs += r
185+
genPredicate(IsNotNull(r), input, child.output)
181186
} else {
182187
""
183188
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3233,6 +3233,32 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession {
32333233
}
32343234
}
32353235
}
3236+
3237+
test("SPARK-29213: FilterExec should not throw NPE") {
3238+
withTempView("t1", "t2", "t3") {
3239+
sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")
3240+
sql("SELECT * FROM VALUES 0, CAST(NULL AS BIGINT)")
3241+
.as[java.lang.Long]
3242+
.map(identity)
3243+
.toDF("x")
3244+
.createOrReplaceTempView("t2")
3245+
sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t3")
3246+
sql(
3247+
"""
3248+
|SELECT t1.x
3249+
|FROM t1
3250+
|LEFT JOIN (
3251+
| SELECT x FROM (
3252+
| SELECT x FROM t2
3253+
| UNION ALL
3254+
| SELECT SUBSTR(x,5) x FROM t3
3255+
| ) a
3256+
| WHERE LENGTH(x)>0
3257+
|) t3
3258+
|ON t1.x=t3.x
3259+
""".stripMargin).collect()
3260+
}
3261+
}
32363262
}
32373263

32383264
case class Foo(bar: Option[String])

0 commit comments

Comments
 (0)