
Commit 0025dde

cloud-fan authored and gatorsmile committed
[SPARK-22472][SQL] add null check for top-level primitive values
## What changes were proposed in this pull request?

One powerful feature of `Dataset` is that we can easily map SQL rows to Scala/Java objects and do the runtime null check automatically. For example, say we have a parquet file with schema `<a: int, b: string>` and a `case class Data(a: Int, b: String)`. Users can easily read this parquet file into `Data` objects, and Spark will throw an NPE if column `a` has null values.

However, the null check is missing for top-level primitive values. For example, say we have a parquet file with schema `<a: Int>` and we read it into a Scala `Int`. If column `a` has null values, we get some weird results:

```
scala> val ds = spark.read.parquet(...).as[Int]

scala> ds.show()
+----+
|v   |
+----+
|null|
|1   |
+----+

scala> ds.collect
res0: Array[Long] = Array(0, 1)

scala> ds.map(_ * 2).show
+-----+
|value|
+-----+
|-2   |
|2    |
+-----+
```

This happens because Spark internally uses special default values for primitive types, but never expects users to see or operate on these default values directly. This PR adds a null check for top-level primitive values.

## How was this patch tested?

new test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19707 from cloud-fan/bug.
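For context, the per-field behavior described above can be sketched as follows. This is an illustrative example only, not part of the patch; the path and object names are placeholders, and it assumes a local SparkSession:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: the nested-field null check that already existed before this patch.
case class Data(a: Int, b: String)

object NestedNullCheckExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._

    // Write a file where column `a` contains a null, then read it back as Data.
    val path = "/tmp/data-with-null.parquet"  // placeholder path
    Seq((Integer.valueOf(1), "x"), (null, "y"))
      .toDF("a", "b").write.mode("overwrite").parquet(path)

    // Decoding the row (null, "y") into Data(a: Int, ...) throws a NullPointerException,
    // because `a` is a non-nullable primitive field.
    spark.read.parquet(path).as[Data].collect()

    spark.stop()
  }
}
```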
1 parent b57ed22 commit 0025dde

File tree

3 files changed: 31 additions, 2 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -134,7 +134,13 @@ object ScalaReflection extends ScalaReflection {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
-    deserializerFor(tpe, None, walkedTypePath)
+    val expr = deserializerFor(tpe, None, walkedTypePath)
+    val Schema(_, nullable) = schemaFor(tpe)
+    if (nullable) {
+      expr
+    } else {
+      AssertNotNull(expr, walkedTypePath)
+    }
   }

   private def deserializerFor(
```
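The change wraps the top-level deserializer expression in `AssertNotNull` whenever the target type's schema is non-nullable. As a rough plain-Scala analogy of the runtime effect (the names and error message below are illustrative, not Spark internals):

```scala
// A plain-Scala sketch of the fail-fast behavior AssertNotNull adds for
// non-nullable targets; `decodePrimitive` and `fieldPath` are illustrative names.
object AssertNotNullSketch {
  def decodePrimitive(boxed: java.lang.Integer, fieldPath: String): Int = {
    if (boxed == null) {
      // Previously a null here silently became the primitive default (0);
      // with the null check it surfaces as an explicit error instead.
      throw new NullPointerException(
        s"Null value appeared in non-nullable field:\n$fieldPath\n" +
          "If the schema can contain nulls, use a boxed type or Option instead.")
    }
    boxed.intValue()
  }

  def main(args: Array[String]): Unit = {
    println(decodePrimitive(Integer.valueOf(1), "- root class: \"scala.Int\"")) // prints 1
    println(decodePrimitive(null, "- root class: \"scala.Int\""))               // throws NPE
  }
}
```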

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 6 additions & 1 deletion
```diff
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -351,4 +351,9 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(argumentsFields(0) == Seq("field.1"))
     assert(argumentsFields(1) == Seq("field 2"))
   }
+
+  test("SPARK-22472: add null check for top-level primitive values") {
+    assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
+    assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
+  }
 }
```
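A note on why only the `Int` case is expected to be wrapped: `schemaFor` reports Scala primitives as non-nullable and reference types such as `String` as nullable, so the latter needs no check. Assuming the catalyst jar is on the classpath (this is an internal API and subject to change), the distinction can be seen directly:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

// Internal API, shown only to illustrate the nullability distinction the test relies on.
object SchemaNullability {
  def main(args: Array[String]): Unit = {
    println(ScalaReflection.schemaFor[Int])    // non-nullable -> gets AssertNotNull
    println(ScalaReflection.schemaFor[String]) // nullable     -> no wrapper needed
  }
}
```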

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 18 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}

+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1408,6 +1409,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       checkDataset(ds, SpecialCharClass("1", "2"))
     }
   }
+
+  test("SPARK-22472: add null check for top-level primitive values") {
+    // If the primitive values are from Option, we need to do runtime null check.
+    val ds = Seq(Some(1), None).toDS().as[Int]
+    intercept[NullPointerException](ds.collect())
+    val e = intercept[SparkException](ds.map(_ * 2).collect())
+    assert(e.getCause.isInstanceOf[NullPointerException])
+
+    withTempPath { path =>
+      Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath)
+      // If the primitive values are from files, we need to do runtime null check.
+      val ds = spark.read.parquet(path.getCanonicalPath).as[Int]
+      intercept[NullPointerException](ds.collect())
+      val e = intercept[SparkException](ds.map(_ * 2).collect())
+      assert(e.getCause.isInstanceOf[NullPointerException])
+    }
+  }
 }

 case class SingleData(id: Int)
```
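As a hedged usage sketch (not part of the patch), the practical takeaway for callers is to decode possibly-null columns into `Option[Int]` or a boxed type rather than a bare `Int`. The snippet below assumes a local SparkSession:

```scala
import org.apache.spark.sql.SparkSession

object TopLevelNullDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-22472 demo").getOrCreate()
    import spark.implicits._

    val ds = Seq(Some(1), None).toDS()   // Dataset[Option[Int]]

    // After this patch, forcing the value into a primitive fails fast
    // instead of silently producing 0.
    try ds.as[Int].collect()
    catch { case e: NullPointerException => println(s"as[Int] failed fast: ${e.getMessage}") }

    // Keeping Option (or a boxed java.lang.Integer) handles the null safely.
    println(ds.collect().toSeq)          // List(Some(1), None)

    spark.stop()
  }
}
```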
