[SPARK-48461][SQL] Replace NullPointerExceptions with error class in AssertNotNull expression

### What changes were proposed in this pull request?

This PR replaces `NullPointerException`s with a new error class in the `AssertNotNull` expression.

### Why are the changes needed?

This brings the advantages of the Spark error-class framework to this case, enabling a better user experience and consistent error classification.

### Does this PR introduce _any_ user-facing change?

Yes: code that previously failed with a bare `NullPointerException` now raises a `SparkRuntimeException` with the error class `NOT_NULL_ASSERT_VIOLATION` (SQLSTATE `42000`).
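
For example, a hedged sketch of the behavior change (spark-shell style; not output copied from this PR's CI — the expected error class and SQLSTATE follow the tests below):

```scala
import org.apache.spark.SparkRuntimeException
import spark.implicits._ // `spark` session is predefined in spark-shell

// Decoding a NULL into a non-nullable Int trips AssertNotNull at runtime.
val ds = Seq(Some(1), None).toDS().as[Int]
try {
  ds.map(_ * 2).collect()
} catch {
  // Before this PR: an unclassified java.lang.NullPointerException.
  // After this PR: a SparkRuntimeException with a stable error class.
  case e: SparkRuntimeException =>
    assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") // SQLSTATE 42000
}
```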

### How was this patch tested?

This PR includes unit test coverage.

### Was this patch authored or co-authored using generative AI tooling?

GitHub Copilot

Closes apache#46793 from dtenedor/fix-npe.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
dtenedor authored and HyukjinKwon committed May 30, 2024
1 parent 9e35b00 commit ce7a889
Showing 12 changed files with 114 additions and 78 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3204,6 +3204,12 @@
],
"sqlState" : "42809"
  },
+  "NOT_NULL_ASSERT_VIOLATION" : {
+    "message" : [
+      "NULL value appeared in non-nullable field: <walkedTypePath>If the schema is inferred from a Scala tuple/case class, or a Java bean, please try to use scala.Option[_] or other nullable types (such as java.lang.Integer instead of int/scala.Int)."
+    ],
+    "sqlState" : "42000"
+  },
"NOT_NULL_CONSTRAINT_VIOLATION" : {
"message" : [
"Assigning a NULL is not allowed here."
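Note that the `<walkedTypePath>` parameter arrives with its own leading and trailing newlines, which is why the template has no whitespace around the placeholder. A minimal sketch of the rendering, using the `mkString` call from the `AssertNotNull` change below:

```scala
// The type path is pre-rendered with surrounding newlines:
val walkedTypePath = Seq("- root class: \"int\"")
val param = walkedTypePath.mkString("\n", "\n", "\n")
// param == "\n- root class: \"int\"\n", so the final message reads:
//   NULL value appeared in non-nullable field:
//   - root class: "int"
//   If the schema is inferred from a Scala tuple/case class, ...
```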
@@ -1917,16 +1917,12 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)

override def flatArguments: Iterator[Any] = Iterator(child)

-  private val errMsg = "Null value appeared in non-nullable field:" +
-    walkedTypePath.mkString("\n", "\n", "\n") +
-    "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
-    "please try to use scala.Option[_] or other nullable types " +
-    "(e.g. java.lang.Integer instead of int/scala.Int)."
+  private val errMsg = walkedTypePath.mkString("\n", "\n", "\n")

override def eval(input: InternalRow): Any = {
val result = child.eval(input)
if (result == null) {
-      throw new NullPointerException(errMsg)
+      throw QueryExecutionErrors.notNullAssertViolation(errMsg)
}
result
}
@@ -1940,7 +1936,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)

val code = childGen.code + code"""
if (${childGen.isNull}) {
-        throw new NullPointerException($errMsgField);
+        throw QueryExecutionErrors.notNullAssertViolation($errMsgField);
}
"""
ev.copy(code = code, isNull = FalseLiteral, value = childGen.value)
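Both execution paths change together here: the interpreted path (`eval`) and the generated-code path (`doGenCode`) now raise the same classified error, so behavior is consistent whether or not whole-stage codegen is enabled.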
@@ -2773,4 +2773,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
)
}

+  def notNullAssertViolation(walkedTypePath: String): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      messageParameters = Map(
+        "walkedTypePath" -> walkedTypePath
+      )
+    )
+  }
}
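
A minimal sketch of the new helper from a caller's point of view (hypothetical values; note the object is `private[sql]`, so this only compiles from inside Spark's own `sql` packages):

```scala
import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.errors.QueryExecutionErrors

// Wraps a pre-rendered type path into a classified runtime error.
val e: SparkRuntimeException =
  QueryExecutionErrors.notNullAssertViolation("\n- root class: \"int\"\n")
assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
assert(e.getSqlState == "42000")
```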
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.encoders

import scala.reflect.runtime.universe.TypeTag

+import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -169,10 +170,17 @@ class EncoderResolutionSuite extends PlanTest {
fromRow(InternalRow(new GenericArrayData(Array(1, 2))))

// If there is null value, it should throw runtime exception
-    val e = intercept[RuntimeException] {
-      fromRow(InternalRow(new GenericArrayData(Array(1, null))))
-    }
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+      },
+      errorClass = "EXPRESSION_DECODING_FAILED",
+      sqlState = "42846",
+      parameters = Map(
+        "expressions" ->
+          ("mapobjects(lambdavariable(MapObject, IntegerType, true, -1), " +
+          "assertnotnull(lambdavariable(MapObject, IntegerType, true, -1)), " +
+          "input[0, array<int>, true], Some(interface scala.collection.immutable.Seq))")))
}

test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.collection.mutable
import scala.util.Random

+import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
@@ -275,9 +276,10 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
test("RowEncoder should throw RuntimeException if input row object is null") {
val schema = new StructType().add("int", IntegerType)
val encoder = ExpressionEncoder(schema)
-    val e = intercept[RuntimeException](toRow(encoder, null))
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getCause.getMessage.contains("top level Product or row object"))
+    // Check the error class only since the parameters may change depending on how we are running
+    // this test case.
+    val exception = intercept[SparkRuntimeException](toRow(encoder, null))
+    assert(exception.getErrorClass == "EXPRESSION_ENCODING_FAILED")
}

test("RowEncoder should validate external type") {
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -53,10 +53,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("AssertNotNUll") {
-    val ex = intercept[RuntimeException] {
-      evaluateWithoutCodegen(AssertNotNull(Literal(null)))
-    }.getMessage
-    assert(ex.contains("Null value appeared in non-nullable field"))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        evaluateWithoutCodegen(AssertNotNull(Literal(null)))
+      },
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      sqlState = "42000",
+      parameters = Map("walkedTypePath" -> "\n\n"))
}

test("IsNaN") {
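The expected `"\n\n"` parameter follows directly from the empty type path: `AssertNotNull(Literal(null))` uses the default `walkedTypePath = Nil`, and `mkString` then emits only its start and end separators:

```scala
assert(Nil.mkString("\n", "\n", "\n") == "\n\n") // start + end, no elements
```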
51 changes: 33 additions & 18 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.Assertions._
import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.TableDrivenPropertyChecks._

-import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
+import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
import org.apache.spark.TestUtils.withListener
import org.apache.spark.internal.config.MAX_RESULT_SIZE
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -1251,11 +1251,10 @@ class DatasetSuite extends QueryTest
// Shouldn't throw runtime exception when parent object (`ClassData`) is null
assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))

-    val message = intercept[RuntimeException] {
+    // Just check the error class here to avoid flakiness due to different parameters.
+    assert(intercept[SparkRuntimeException] {
      buildDataset(Row(Row("hello", null))).collect()
-    }.getCause.getMessage
-
-    assert(message.contains("Null value appeared in non-nullable field"))
+    }.getErrorClass == "EXPRESSION_DECODING_FAILED")
}

test("SPARK-12478: top level null field") {
@@ -1593,9 +1592,8 @@
}

test("Dataset should throw RuntimeException if top-level product input object is null") {
-    val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
-    assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getCause.getMessage.contains("top level Product or row object"))
+    val e = intercept[SparkRuntimeException](Seq(ClassData("a", 1), null).toDS())
+    assert(e.getErrorClass == "EXPRESSION_ENCODING_FAILED")
}

test("dropDuplicates") {
@@ -2038,19 +2036,33 @@
test("SPARK-22472: add null check for top-level primitive values") {
// If the primitive values are from Option, we need to do runtime null check.
val ds = Seq(Some(1), None).toDS().as[Int]
-    val e1 = intercept[RuntimeException](ds.collect())
-    assert(e1.getCause.isInstanceOf[NullPointerException])
-    val e2 = intercept[SparkException](ds.map(_ * 2).collect())
-    assert(e2.getCause.isInstanceOf[NullPointerException])
+    val errorClass = "EXPRESSION_DECODING_FAILED"
+    val sqlState = "42846"
+    checkError(
+      exception = intercept[SparkRuntimeException](ds.collect()),
+      errorClass = "EXPRESSION_DECODING_FAILED",
+      sqlState = "42846",
+      parameters = Map("expressions" -> "assertnotnull(input[0, int, true])"))
+    checkError(
+      exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()),
+      errorClass = "NOT_NULL_ASSERT_VIOLATION",
+      sqlState = "42000",
+      parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n"))

withTempPath { path =>
Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath)
// If the primitive values are from files, we need to do runtime null check.
val ds = spark.read.parquet(path.getCanonicalPath).as[Int]
-      val e1 = intercept[RuntimeException](ds.collect())
-      assert(e1.getCause.isInstanceOf[NullPointerException])
-      val e2 = intercept[SparkException](ds.map(_ * 2).collect())
-      assert(e2.getCause.isInstanceOf[NullPointerException])
+      checkError(
+        exception = intercept[SparkRuntimeException](ds.collect()),
+        errorClass = "EXPRESSION_DECODING_FAILED",
+        sqlState = "42846",
+        parameters = Map("expressions" -> "assertnotnull(input[0, int, true])"))
+      checkError(
+        exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()),
+        errorClass = "NOT_NULL_ASSERT_VIOLATION",
+        sqlState = "42000",
+        parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n"))
}
}
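Note the asymmetry the two assertions above rely on: `ds.collect()` decodes rows to objects through the encoder, which wraps the failure as `EXPRESSION_DECODING_FAILED`, whereas `ds.map(_ * 2)` runs the null check inside the query itself, so the new `NOT_NULL_ASSERT_VIOLATION` surfaces directly.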

@@ -2068,8 +2080,11 @@

test("SPARK-23835: null primitive data type should throw NullPointerException") {
val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
-    val e = intercept[RuntimeException](ds.as[(Int, Int)].collect())
-    assert(e.getCause.isInstanceOf[NullPointerException])
+    checkError(
+      exception = intercept[SparkRuntimeException](ds.as[(Int, Int)].collect()),
+      errorClass = "EXPRESSION_DECODING_FAILED",
+      sqlState = "42846",
+      parameters = Map("expressions" -> "newInstance(class scala.Tuple2)"))
}

test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") {
@@ -19,7 +19,7 @@ package org.apache.spark.sql

import java.util.Collections

-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkRuntimeException}
import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.SQLConf
@@ -56,15 +56,15 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
withTable("t") {
sql(s"CREATE TABLE t (s STRING, i INT NOT NULL) USING $FORMAT")

-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql("SELECT 'txt' AS s, null AS i")
inputDF.writeTo("t").append()
} else {
sql("INSERT INTO t VALUES ('txt', null)")
}
}
-      assertNotNullException(e, Seq("i"))
+      assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
}
}

@@ -88,7 +88,7 @@
|USING $FORMAT
""".stripMargin)

-      val e1 = intercept[SparkException] {
+      val e1 = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -106,7 +106,7 @@
}
assertNotNullException(e1, Seq("s", "ns"))

-      val e2 = intercept[SparkException] {
+      val e2 = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -124,7 +124,7 @@
}
assertNotNullException(e2, Seq("s", "arr"))

-      val e3 = intercept[SparkException] {
+      val e3 = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -177,7 +177,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
}
checkAnswer(spark.table("t"), Row(1, Row(1, null)))

-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -224,7 +224,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
}
checkAnswer(spark.table("t"), Row(1, null))

-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -279,7 +279,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
}
checkAnswer(spark.table("t"), Row(1, List(null, Row(1, 1))))

-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -325,7 +325,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
}
checkAnswer(spark.table("t"), Row(1, null))

-      val e = intercept[SparkException] {
+      val e = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql("SELECT 1 AS i, map(1, null) AS m")
inputDF.writeTo("t").append()
@@ -364,7 +364,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS
}
checkAnswer(spark.table("t"), Row(1, Map(Row(1, 1) -> null)))

-      val e1 = intercept[SparkException] {
+      val e1 = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -382,7 +382,7 @@
}
assertNotNullException(e1, Seq("m", "key", "x"))

-      val e2 = intercept[SparkException] {
+      val e2 = intercept[SparkRuntimeException] {
if (byName) {
val inputDF = sql(
s"""SELECT
@@ -402,11 +402,9 @@
}
}

-  private def assertNotNullException(e: SparkException, colPath: Seq[String]): Unit = {
+  private def assertNotNullException(e: SparkRuntimeException, colPath: Seq[String]): Unit = {
    e.getCause match {
-      case npe: NullPointerException =>
-        assert(npe.getMessage.contains("Null value appeared in non-nullable field"))
-        assert(npe.getMessage.contains(colPath.mkString("\n", "\n", "\n")))
+      case _ if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" =>
case other =>
fail(s"Unexpected exception cause: $other")
}
@@ -24,7 +24,7 @@ import java.util.Locale
import scala.concurrent.duration.MICROSECONDS
import scala.jdk.CollectionConverters._

-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.InternalRow
@@ -814,14 +814,10 @@ class DataSourceV2SQLSuiteV1Filter
if (nullable) {
insertNullValueAndCheck()
} else {
-        // TODO assign a error-classes name
-        checkError(
-          exception = intercept[SparkException] {
-            insertNullValueAndCheck()
-          },
-          errorClass = null,
-          parameters = Map.empty
-        )
+        val exception = intercept[SparkRuntimeException] {
+          insertNullValueAndCheck()
+        }
+        assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION")
}
}
}