diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 66708649e5646..3914c0f177dcb 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3204,6 +3204,12 @@ ], "sqlState" : "42809" }, + "NOT_NULL_ASSERT_VIOLATION" : { + "message" : [ + "NULL value appeared in non-nullable field: If the schema is inferred from a Scala tuple/case class, or a Java bean, please try to use scala.Option[_] or other nullable types (such as java.lang.Integer instead of int/scala.Int)." + ], + "sqlState" : "42000" + }, "NOT_NULL_CONSTRAINT_VIOLATION" : { "message" : [ "Assigning a NULL is not allowed here." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 462facd180c4e..32d8eebd01ce0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1917,16 +1917,12 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) override def flatArguments: Iterator[Any] = Iterator(child) - private val errMsg = "Null value appeared in non-nullable field:" + - walkedTypePath.mkString("\n", "\n", "\n") + - "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + - "please try to use scala.Option[_] or other nullable types " + - "(e.g. java.lang.Integer instead of int/scala.Int)." + private val errMsg = walkedTypePath.mkString("\n", "\n", "\n") override def eval(input: InternalRow): Any = { val result = child.eval(input) if (result == null) { - throw new NullPointerException(errMsg) + throw QueryExecutionErrors.notNullAssertViolation(errMsg) } result } @@ -1940,7 +1936,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) val code = childGen.code + code""" if (${childGen.isNull}) { - throw new NullPointerException($errMsgField); + throw QueryExecutionErrors.notNullAssertViolation($errMsgField); } """ ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 1f3283ebed059..f587d87284f3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2773,4 +2773,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ) ) } + + def notNullAssertViolation(walkedTypePath: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "NOT_NULL_ASSERT_VIOLATION", + messageParameters = Map( + "walkedTypePath" -> walkedTypePath + ) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 82238de31f9fb..9ca990b607db0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -169,10 +170,17 @@ class EncoderResolutionSuite extends PlanTest { fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) // If there is null value, it should throw runtime exception - val e = intercept[RuntimeException] { - fromRow(InternalRow(new GenericArrayData(Array(1, null)))) - } - assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field")) + checkError( + exception = intercept[SparkRuntimeException] { + fromRow(InternalRow(new GenericArrayData(Array(1, null)))) + }, + errorClass = "EXPRESSION_DECODING_FAILED", + sqlState = "42846", + parameters = Map( + "expressions" -> + ("mapobjects(lambdavariable(MapObject, IntegerType, true, -1), " + + "assertnotnull(lambdavariable(MapObject, IntegerType, true, -1)), " + + "input[0, array, true], Some(interface scala.collection.immutable.Seq))"))) } test("the real number of fields doesn't match encoder schema: tuple encoder") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index df73d50fdcd6b..943499fde84f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.collection.mutable import scala.util.Random +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest @@ -275,9 +276,10 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { test("RowEncoder should throw RuntimeException if input row object is null") { val schema = new StructType().add("int", IntegerType) val encoder = ExpressionEncoder(schema) - val e = intercept[RuntimeException](toRow(encoder, null)) - assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getCause.getMessage.contains("top level Product or row object")) + // Check the error class only since the parameters may change depending on how we are running + // this test case. + val exception = intercept[SparkRuntimeException](toRow(encoder, null)) + assert(exception.getErrorClass == "EXPRESSION_ENCODING_FAILED") } test("RowEncoder should validate external type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index da8e11c0433eb..ace017b1cddc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull @@ -53,10 +53,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("AssertNotNUll") { - val ex = intercept[RuntimeException] { - evaluateWithoutCodegen(AssertNotNull(Literal(null))) - }.getMessage - assert(ex.contains("Null value appeared in non-nullable field")) + checkError( + exception = intercept[SparkRuntimeException] { + evaluateWithoutCodegen(AssertNotNull(Literal(null))) + }, + errorClass = "NOT_NULL_ASSERT_VIOLATION", + sqlState = "42000", + parameters = Map("walkedTypePath" -> "\n\n")) } test("IsNaN") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 16a493b52909e..10d6f045db399 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.Assertions._ import org.scalatest.exceptions.TestFailedException import org.scalatest.prop.TableDrivenPropertyChecks._ -import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext} +import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext} import org.apache.spark.TestUtils.withListener import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -1251,11 +1251,10 @@ class DatasetSuite extends QueryTest // Shouldn't throw runtime exception when parent object (`ClassData`) is null assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null))) - val message = intercept[RuntimeException] { + // Just check the error class here to avoid flakiness due to different parameters. + assert(intercept[SparkRuntimeException] { buildDataset(Row(Row("hello", null))).collect() - }.getCause.getMessage - - assert(message.contains("Null value appeared in non-nullable field")) + }.getErrorClass == "EXPRESSION_DECODING_FAILED") } test("SPARK-12478: top level null field") { @@ -1593,9 +1592,8 @@ class DatasetSuite extends QueryTest } test("Dataset should throw RuntimeException if top-level product input object is null") { - val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) - assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getCause.getMessage.contains("top level Product or row object")) + val e = intercept[SparkRuntimeException](Seq(ClassData("a", 1), null).toDS()) + assert(e.getErrorClass == "EXPRESSION_ENCODING_FAILED") } test("dropDuplicates") { @@ -2038,19 +2036,33 @@ class DatasetSuite extends QueryTest test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] - val e1 = intercept[RuntimeException](ds.collect()) - assert(e1.getCause.isInstanceOf[NullPointerException]) - val e2 = intercept[SparkException](ds.map(_ * 2).collect()) - assert(e2.getCause.isInstanceOf[NullPointerException]) + val errorClass = "EXPRESSION_DECODING_FAILED" + val sqlState = "42846" + checkError( + exception = intercept[SparkRuntimeException](ds.collect()), + errorClass = "EXPRESSION_DECODING_FAILED", + sqlState = "42846", + parameters = Map("expressions" -> "assertnotnull(input[0, int, true])")) + checkError( + exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()), + errorClass = "NOT_NULL_ASSERT_VIOLATION", + sqlState = "42000", + parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n")) withTempPath { path => Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath) // If the primitive values are from files, we need to do runtime null check. val ds = spark.read.parquet(path.getCanonicalPath).as[Int] - val e1 = intercept[RuntimeException](ds.collect()) - assert(e1.getCause.isInstanceOf[NullPointerException]) - val e2 = intercept[SparkException](ds.map(_ * 2).collect()) - assert(e2.getCause.isInstanceOf[NullPointerException]) + checkError( + exception = intercept[SparkRuntimeException](ds.collect()), + errorClass = "EXPRESSION_DECODING_FAILED", + sqlState = "42846", + parameters = Map("expressions" -> "assertnotnull(input[0, int, true])")) + checkError( + exception = intercept[SparkRuntimeException](ds.map(_ * 2).collect()), + errorClass = "NOT_NULL_ASSERT_VIOLATION", + sqlState = "42000", + parameters = Map("walkedTypePath" -> "\n- root class: \"int\"\n")) } } @@ -2068,8 +2080,11 @@ class DatasetSuite extends QueryTest test("SPARK-23835: null primitive data type should throw NullPointerException") { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() - val e = intercept[RuntimeException](ds.as[(Int, Int)].collect()) - assert(e.getCause.isInstanceOf[NullPointerException]) + checkError( + exception = intercept[SparkRuntimeException](ds.as[(Int, Int)].collect()), + errorClass = "EXPRESSION_DECODING_FAILED", + sqlState = "42846", + parameters = Map("expressions" -> "newInstance(class scala.Tuple2)")) } test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala index fbdd1428ba9b8..754c46cc5cd3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.util.Collections -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkRuntimeException} import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, Identifier, InMemoryTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf @@ -56,7 +56,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS withTable("t") { sql(s"CREATE TABLE t (s STRING, i INT NOT NULL) USING $FORMAT") - val e = intercept[SparkException] { + val e = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql("SELECT 'txt' AS s, null AS i") inputDF.writeTo("t").append() @@ -64,7 +64,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS sql("INSERT INTO t VALUES ('txt', null)") } } - assertNotNullException(e, Seq("i")) + assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") } } @@ -88,7 +88,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS |USING $FORMAT """.stripMargin) - val e1 = intercept[SparkException] { + val e1 = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -106,7 +106,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } assertNotNullException(e1, Seq("s", "ns")) - val e2 = intercept[SparkException] { + val e2 = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -124,7 +124,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } assertNotNullException(e2, Seq("s", "arr")) - val e3 = intercept[SparkException] { + val e3 = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -177,7 +177,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } checkAnswer(spark.table("t"), Row(1, Row(1, null))) - val e = intercept[SparkException] { + val e = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -224,7 +224,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } checkAnswer(spark.table("t"), Row(1, null)) - val e = intercept[SparkException] { + val e = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -279,7 +279,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } checkAnswer(spark.table("t"), Row(1, List(null, Row(1, 1)))) - val e = intercept[SparkException] { + val e = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -325,7 +325,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } checkAnswer(spark.table("t"), Row(1, null)) - val e = intercept[SparkException] { + val e = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql("SELECT 1 AS i, map(1, null) AS m") inputDF.writeTo("t").append() @@ -364,7 +364,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } checkAnswer(spark.table("t"), Row(1, Map(Row(1, 1) -> null))) - val e1 = intercept[SparkException] { + val e1 = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -382,7 +382,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } assertNotNullException(e1, Seq("m", "key", "x")) - val e2 = intercept[SparkException] { + val e2 = intercept[SparkRuntimeException] { if (byName) { val inputDF = sql( s"""SELECT @@ -402,11 +402,9 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS } } - private def assertNotNullException(e: SparkException, colPath: Seq[String]): Unit = { + private def assertNotNullException(e: SparkRuntimeException, colPath: Seq[String]): Unit = { e.getCause match { - case npe: NullPointerException => - assert(npe.getMessage.contains("Null value appeared in non-nullable field")) - assert(npe.getMessage.contains(colPath.mkString("\n", "\n", "\n"))) + case _ if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" => case other => fail(s"Unexpected exception cause: $other") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index d89c0a2525fd9..14b9feb2951a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -24,7 +24,7 @@ import java.util.Locale import scala.concurrent.duration.MICROSECONDS import scala.jdk.CollectionConverters._ -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER import org.apache.spark.sql.catalyst.InternalRow @@ -814,14 +814,10 @@ class DataSourceV2SQLSuiteV1Filter if (nullable) { insertNullValueAndCheck() } else { - // TODO assign a error-classes name - checkError( - exception = intercept[SparkException] { - insertNullValueAndCheck() - }, - errorClass = null, - parameters = Map.empty - ) + val exception = intercept[SparkRuntimeException] { + insertNullValueAndCheck() + } + assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 0b643ca534e39..9d4e4fc016722 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector -import org.apache.spark.{SparkException, SparkRuntimeException} +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not} import org.apache.spark.sql.catalyst.optimizer.BuildLeft @@ -1317,7 +1317,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { Seq(1, 4).toDF("pk").createOrReplaceTempView("source") - val e1 = intercept[SparkException] { + val e1 = intercept[SparkRuntimeException] { sql( s"""MERGE INTO $tableNameAsString t |USING source s @@ -1326,9 +1326,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) |""".stripMargin) } - assert(e1.getCause.getMessage.contains("Null value appeared in non-nullable field")) + assert(e1.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") - val e2 = intercept[SparkException] { + val e2 = intercept[SparkRuntimeException] { sql( s"""MERGE INTO $tableNameAsString t |USING source s @@ -1337,9 +1337,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) |""".stripMargin) } - assert(e2.getCause.getMessage.contains("Null value appeared in non-nullable field")) + assert(e2.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") - val e3 = intercept[SparkException] { + val e3 = intercept[SparkRuntimeException] { sql( s"""MERGE INTO $tableNameAsString t |USING source s @@ -1348,7 +1348,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | INSERT (pk, s, dep) VALUES (s.pk, named_struct('n_i', null, 'n_l', -1L), 'invalid') |""".stripMargin) } - assert(e3.getCause.getMessage.contains("Null value appeared in non-nullable field")) + assert(e3.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index b43101c2e0255..c2ae5f40cfaf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector -import org.apache.spark.SparkException +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.Row import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue} import org.apache.spark.sql.connector.expressions.LiteralValue @@ -575,9 +575,12 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { |{ "pk": 3, "s": { "n_i": 3, "n_l": 33 }, "dep": "hr" } |""".stripMargin) - val e = intercept[SparkException] { - sql(s"UPDATE $tableNameAsString SET s = named_struct('n_i', null, 'n_l', -1L) WHERE pk = 1") - } - assert(e.getCause.getMessage.contains("Null value appeared in non-nullable field")) + checkError( + exception = intercept[SparkRuntimeException] { + sql(s"UPDATE $tableNameAsString SET s = named_struct('n_i', null, 'n_l', -1L) WHERE pk = 1") + }, + errorClass = "NOT_NULL_ASSERT_VIOLATION", + sqlState = "42000", + parameters = Map("walkedTypePath" -> "\ns\nn_i\n")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 93698fdd7bc0f..e3e385e9d1810 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -23,7 +23,7 @@ import java.time.{Duration, Period} import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataOutputStream, Path, RawLocalFileSystem} -import org.apache.spark.{SparkArithmeticException, SparkException} +import org.apache.spark.{SparkArithmeticException, SparkException, SparkRuntimeException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} @@ -953,10 +953,10 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { spark.sessionState.catalog.createTable(newTable, false) sql("INSERT INTO TABLE test_table SELECT 1, 'a'") - val msg = intercept[SparkException] { + val msg = intercept[SparkRuntimeException] { sql("INSERT INTO TABLE test_table SELECT 2, null") - }.getCause.getMessage - assert(msg.contains("Null value appeared in non-nullable field")) + } + assert(msg.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") } }