Skip to content

Implicit DML casts: Fix Delta version in error message and add more tests #1944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1257,8 +1257,8 @@ trait DeltaSQLConfBase {
buildConf("updateAndMergeCastingFollowsAnsiEnabledFlag")
.internal()
.doc("""If false, casting behaviour in implicit casts in UPDATE and MERGE follows
|'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'ansi.enabled'. This
|was the default before Delta 3.5.""".stripMargin)
|'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'ansi.enabled'.
|""".stripMargin)
.booleanConf
.createWithDefault(false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ class ImplicitDMLCastingSuite extends QueryTest
targetType: String,
targetTypeInErrorMessage: String,
validValue: String,
overflowValue: String)
overflowValue: String,
// String because SparkArithmeticException is private and cannot be used for matching.
exceptionAnsiCast: String
) {
override def toString: String = s"sourceType: $sourceType, targetType: $targetType"
}

private case class SqlConfiguration(
followAnsiEnabled: Boolean,
Expand All @@ -59,36 +64,61 @@ class ImplicitDMLCastingSuite extends QueryTest
s" storeAssignmentPolicy: $storeAssignmentPolicy"
}

/**
 * Returns true when the effective SQL configuration selects the legacy (non-failing)
 * casting behaviour for implicit DML casts, i.e. overflow silently wraps instead of
 * raising an error.
 *
 * Which flag governs the behaviour depends on `followAnsiEnabled`:
 *  - if casts follow ANSI mode, legacy behaviour means ANSI is off;
 *  - otherwise, legacy behaviour means the store-assignment policy is LEGACY.
 */
private def expectLegacyCastingBehaviour(sqlConfig: SqlConfiguration): Boolean = {
  if (sqlConfig.followAnsiEnabled) {
    !sqlConfig.ansiEnabled
  } else {
    sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY
  }
}

// Note that DATE to TIMESTAMP casts are not in this list as they always throw an error on
// overflow no matter if ANSI is enabled or not.
private val testConfigurations = Seq(
TestConfiguration(sourceType = "INT", sourceTypeInErrorMessage = "INT",
targetType = "TINYINT", targetTypeInErrorMessage = "TINYINT",
validValue = "1", overflowValue = Int.MaxValue.toString),
validValue = "1", overflowValue = Int.MaxValue.toString,
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "INT", sourceTypeInErrorMessage = "INT",
targetType = "SMALLINT", targetTypeInErrorMessage = "SMALLINT",
validValue = "1", overflowValue = Int.MaxValue.toString),
validValue = "1", overflowValue = Int.MaxValue.toString,
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "BIGINT", sourceTypeInErrorMessage = "BIGINT",
targetType = "INT", targetTypeInErrorMessage = "INT",
validValue = "1", overflowValue = Long.MaxValue.toString),
validValue = "1", overflowValue = Long.MaxValue.toString,
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "DOUBLE", sourceTypeInErrorMessage = "DOUBLE",
targetType = "BIGINT", targetTypeInErrorMessage = "BIGINT",
validValue = "1", overflowValue = "12345678901234567890D"),
validValue = "1", overflowValue = "12345678901234567890D",
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "BIGINT", sourceTypeInErrorMessage = "BIGINT",
targetType = "DECIMAL(7,2)", targetTypeInErrorMessage = "DECIMAL(7,2)",
validValue = "1", overflowValue = Long.MaxValue.toString),
validValue = "1", overflowValue = Long.MaxValue.toString,
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "Struct<value:BIGINT>", sourceTypeInErrorMessage = "BIGINT",
targetType = "Struct<value:INT>", targetTypeInErrorMessage = "INT",
validValue = "named_struct('value', 1)",
overflowValue = s"named_struct('value', ${Long.MaxValue.toString})"),
overflowValue = s"named_struct('value', ${Long.MaxValue.toString})",
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "ARRAY<BIGINT>", sourceTypeInErrorMessage = "ARRAY<BIGINT>",
targetType = "ARRAY<INT>", targetTypeInErrorMessage = "ARRAY<INT>",
validValue = "ARRAY(1)", overflowValue = s"ARRAY(${Long.MaxValue.toString})")
validValue = "ARRAY(1)", overflowValue = s"ARRAY(${Long.MaxValue.toString})",
exceptionAnsiCast = "SparkArithmeticException"),
TestConfiguration(sourceType = "STRING", sourceTypeInErrorMessage = "STRING",
targetType = "INT", targetTypeInErrorMessage = "INT",
validValue = "'1'", overflowValue = s"'${Long.MaxValue.toString}'",
exceptionAnsiCast = "SparkNumberFormatException"),
TestConfiguration(sourceType = "MAP<STRING, BIGINT>",
sourceTypeInErrorMessage = "MAP<STRING, BIGINT>", targetType = "MAP<STRING, INT>",
targetTypeInErrorMessage = "MAP<STRING, INT>", validValue = "map('abc', 1)",
overflowValue = s"map('abc', ${Long.MaxValue.toString})",
exceptionAnsiCast = "SparkArithmeticException")
)

@tailrec
private def arithmeticCause(exception: Throwable): Option[ArithmeticException] = {
private def castFailureCause(exception: Throwable): Option[Throwable] = {
exception match {
case arithmeticException: ArithmeticException => Some(arithmeticException)
case _ if exception.getCause != null => arithmeticCause(exception.getCause)
case numberFormatException: NumberFormatException => Some(numberFormatException)
case _ if exception.getCause != null => castFailureCause(exception.getCause)
case _ => None
}
}
Expand All @@ -99,24 +129,35 @@ class ImplicitDMLCastingSuite extends QueryTest
*/
private def validateException(
exception: Throwable, sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
arithmeticCause(exception) match {
case Some(exception: DeltaArithmeticException) =>
assert(exception.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
assert(exception.getMessageParameters ==
Map("sourceType" -> ("\"" + testConfig.sourceTypeInErrorMessage + "\""),
"targetType" -> ("\"" + testConfig.targetTypeInErrorMessage + "\""),
"columnName" -> "`value`",
"storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
"updateAndMergeCastingFollowsAnsiEnabledFlag" ->
DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
"ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key).asJava)
case Some(exception: SparkThrowable) if sqlConfig.ansiEnabled =>
// With ANSI enabled the overflows are caught before the write operation.
assert(Seq("CAST_OVERFLOW", "NUMERIC_VALUE_OUT_OF_RANGE")
.contains(exception.getErrorClass))
case None => assert(false, "No arithmetic exception thrown.")
case Some(exception) =>
assert(false, s"Unexpected exception type: $exception")
// Validate that the type of error matches the expected error type.
castFailureCause(exception) match {
case Some(failureCause) if sqlConfig.followAnsiEnabled =>
assert(sqlConfig.ansiEnabled)
assert(failureCause.toString.contains(testConfig.exceptionAnsiCast))

val sparkThrowable = failureCause.asInstanceOf[SparkThrowable]
assert(Seq("CAST_OVERFLOW", "NUMERIC_VALUE_OUT_OF_RANGE", "CAST_INVALID_INPUT")
.contains(sparkThrowable.getErrorClass))
case Some(failureCause) if !sqlConfig.followAnsiEnabled =>
assert(sqlConfig.storeAssignmentPolicy === SQLConf.StoreAssignmentPolicy.ANSI)

val sparkThrowable = failureCause.asInstanceOf[SparkThrowable]
// Only arithmetic exceptions get a custom error message.
if (testConfig.exceptionAnsiCast == "SparkArithmeticException") {
assert(sparkThrowable.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
assert(sparkThrowable.getMessageParameters ==
Map("sourceType" -> ("\"" + testConfig.sourceTypeInErrorMessage + "\""),
"targetType" -> ("\"" + testConfig.targetTypeInErrorMessage + "\""),
"columnName" -> "`value`",
"storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
"updateAndMergeCastingFollowsAnsiEnabledFlag" ->
DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
"ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key).asJava)
} else {
assert(sparkThrowable.getErrorClass == "CAST_INVALID_INPUT")
assert(sparkThrowable.getMessageParameters.get("sourceType") == "\"STRING\"")
}
case None => assert(false, s"No arithmetic exception thrown: $exception")
}
}

Expand All @@ -138,7 +179,7 @@ class ImplicitDMLCastingSuite extends QueryTest
/** Test an UPDATE that requires to cast the update value that is part of the SET clause. */
private def updateTest(
sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
val testName = s"UPDATE overflow targetType: ${testConfig.targetType} $sqlConfig"
val testName = s"UPDATE overflow $testConfig $sqlConfig"
test(testName) {
sqlConfig.withSqlSettings {
val tableName = "overflowTable"
Expand All @@ -148,11 +189,7 @@ class ImplicitDMLCastingSuite extends QueryTest
|""".stripMargin)
val updateCommand = s"UPDATE $tableName SET value = ${testConfig.overflowValue}"

val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) ||
(!sqlConfig.followAnsiEnabled &&
sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY)

if (legacyCasts) {
if (expectLegacyCastingBehaviour(sqlConfig)) {
sql(updateCommand)
} else {
val exception = intercept[Throwable] {
Expand Down Expand Up @@ -185,12 +222,11 @@ class ImplicitDMLCastingSuite extends QueryTest
sqlConfig: SqlConfiguration,
testConfig: TestConfiguration
): Unit = {
val testName =
s"MERGE overflow in $matchedCondition targetType: ${testConfig.targetType} $sqlConfig"
val testName = s"MERGE overflow in $matchedCondition $testConfig $sqlConfig"
test(testName) {
sqlConfig.withSqlSettings {
val targetTableName = "target_table"
val sourceViewName = "source_vice"
val sourceViewName = "source_view"
withTable(targetTableName) {
withTempView(sourceViewName) {
val numRows = 10
Expand All @@ -209,11 +245,8 @@ class ImplicitDMLCastingSuite extends QueryTest
|ON s.key = t.key
|$matchedCondition
|""".stripMargin
val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) ||
(!sqlConfig.followAnsiEnabled &&
sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY)

if (legacyCasts) {
if (expectLegacyCastingBehaviour(sqlConfig)) {
sql(mergeCommand)
} else {
val exception = intercept[Throwable] {
Expand All @@ -231,7 +264,7 @@ class ImplicitDMLCastingSuite extends QueryTest
/** A merge that is executed for each batch of a stream and has to cast values before insert. */
private def streamingMergeTest(
sqlConfig: SqlConfiguration, testConfig: TestConfiguration): Unit = {
val testName = s"Streaming MERGE overflow targetType: ${testConfig.targetType} $sqlConfig"
val testName = s"Streaming MERGE overflow $testConfig $sqlConfig"
test(testName) {
sqlConfig.withSqlSettings {
val targetTableName = "target_table"
Expand Down Expand Up @@ -263,11 +296,7 @@ class ImplicitDMLCastingSuite extends QueryTest

sql(s"INSERT INTO $sourceTableName(key, value) VALUES(0, ${testConfig.overflowValue})")

val legacyCasts = (sqlConfig.followAnsiEnabled && !sqlConfig.ansiEnabled) ||
(!sqlConfig.followAnsiEnabled &&
sqlConfig.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.LEGACY)

if (legacyCasts) {
if (expectLegacyCastingBehaviour(sqlConfig)) {
streamWriter.processAllAvailable()
} else {
val exception = intercept[Throwable] {
Expand Down