Skip to content

[SPARK-28741][SQL]Optional mode: throw exceptions when casting to integers causes overflow #25461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

private lazy val dateFormatter = DateFormatter()
private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
private val failOnIntegralTypeOverflow = SQLConf.get.failOnIntegralTypeOverflow

// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
Expand Down Expand Up @@ -461,6 +462,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Int](_, d => null)
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t))
case x: NumericType if failOnIntegralTypeOverflow =>
b => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}
Expand All @@ -474,8 +477,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType if failOnIntegralTypeOverflow =>
buildCast[Long](_, t => LongExactNumeric.toInt(timestampToLong(t)))
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t).toInt)
case x: NumericType if failOnIntegralTypeOverflow =>
b => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}
Expand All @@ -493,8 +500,30 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType if failOnIntegralTypeOverflow =>
buildCast[Long](_, t => {
val longValue = timestampToLong(t)
if (longValue == longValue.toShort) {
longValue.toShort
} else {
throw new ArithmeticException(s"Casting $t to short causes overflow.")
}
})
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t).toShort)
case x: NumericType if failOnIntegralTypeOverflow =>
b =>
val intValue = try {
x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you cast it into int once?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trait Numeric doesn't have the method toInt. Before this code change, the value is also casted to int.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot check the valid value range in a single place instead of the current two checks in line 520 and 525?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we can do it by match it case by case. Then the code is a bit long. Casting to short/byte should be minor usage. Also, The previous code also cast to Int before cast to Short.

} catch {
case _: ArithmeticException =>
throw new ArithmeticException(s"Casting $b to short causes overflow.")
}
if (intValue == intValue.toShort) {
intValue.toShort
} else {
throw new ArithmeticException(s"Casting $b to short causes overflow.")
}
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}
Expand All @@ -512,8 +541,30 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType if failOnIntegralTypeOverflow =>
buildCast[Long](_, t => {
val longValue = timestampToLong(t)
if (longValue == longValue.toByte) {
longValue.toByte
} else {
throw new ArithmeticException(s"Casting $t to byte causes overflow.")
}
})
case TimestampType =>
buildCast[Long](_, t => timestampToLong(t).toByte)
case x: NumericType if failOnIntegralTypeOverflow =>
b =>
val intValue = try {
x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b)
} catch {
case _: ArithmeticException =>
throw new ArithmeticException(s"Casting $b to byte causes overflow.")
}
if (intValue == intValue.toByte) {
intValue.toByte
} else {
throw new ArithmeticException(s"Casting $b to byte causes overflow.")
}
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}
Expand Down Expand Up @@ -1153,7 +1204,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()"
}
private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * (long)$MICROS_PER_SECOND"
private[this] def timestampToIntegerCode(ts: ExprValue): Block =
private[this] def timestampToLongCode(ts: ExprValue): Block =
code"java.lang.Math.floorDiv($ts, $MICROS_PER_SECOND)"
private[this] def timestampToDoubleCode(ts: ExprValue): Block =
code"$ts / (double)$MICROS_PER_SECOND"
Expand Down Expand Up @@ -1182,6 +1233,82 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = $c != 0;"
}

private[this] def castTimestampToIntegralTypeCode(
ctx: CodegenContext,
integralType: String): CastFunction = {
if (failOnIntegralTypeOverflow) {
val longValue = ctx.freshName("longValue")
(c, evPrim, evNull) =>
code"""
long $longValue = ${timestampToLongCode(c)};
if ($longValue == ($integralType) $longValue) {
$evPrim = ($integralType) $longValue;
} else {
throw new ArithmeticException("Casting $c to $integralType causes overflow");
}
"""
} else {
(c, evPrim, evNull) => code"$evPrim = ($integralType) ${timestampToLongCode(c)};"
}
}

private[this] def castDecimalToIntegralTypeCode(
ctx: CodegenContext,
integralType: String): CastFunction = {
if (failOnIntegralTypeOverflow) {
(c, evPrim, evNull) => code"$evPrim = $c.roundTo${integralType.capitalize}();"
} else {
(c, evPrim, evNull) => code"$evPrim = $c.to${integralType.capitalize}();"
}
}

private[this] def castIntegralTypeToIntegralTypeExactCode(integralType: String): CastFunction = {
assert(failOnIntegralTypeOverflow)
(c, evPrim, evNull) =>
code"""
if ($c == ($integralType) $c) {
$evPrim = ($integralType) $c;
} else {
throw new ArithmeticException("Casting $c to $integralType causes overflow");
}
"""
}

private[this] def lowerAndUpperBound(
fractionType: String,
integralType: String): (String, String) = {
assert(fractionType == "float" || fractionType == "double")
val typeIndicator = fractionType.charAt(0)
val (min, max) = integralType.toLowerCase(Locale.ROOT) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about

val max = s"${ctx.primitiveTypeName(integralType).MaxValue}"
val min = s"${ctx.primitiveTypeName(integralType).MinValue}"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then it will be string "Int.MaxValue". Here I am trying to get string like "2147483647f" , "2147483647d"

case "long" => (Long.MinValue, Long.MaxValue)
case "int" => (Int.MinValue, Int.MaxValue)
case "short" => (Short.MinValue, Short.MaxValue)
case "byte" => (Byte.MinValue, Byte.MaxValue)
}
(min.toString + typeIndicator, max.toString + typeIndicator)
}

private[this] def castFractionToIntegralTypeCode(
fractionType: String,
integralType: String): CastFunction = {
assert(failOnIntegralTypeOverflow)
val (min, max) = lowerAndUpperBound(fractionType, integralType)
val mathClass = classOf[Math].getName
// When casting floating values to integral types, Spark uses the method `Numeric.toInt`
// Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`;
// for negative floating values, it is equivalent to `Math.ceil`.
// So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
// to check if the floating value x is in the range of an integral type after rounding.
(c, evPrim, evNull) =>
code"""
if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
$evPrim = ($integralType) $c;
} else {
throw new ArithmeticException("Casting $c to $integralType causes overflow");
}
"""
}

private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
Expand All @@ -1199,10 +1326,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (byte) ${timestampToIntegerCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toByte();"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "byte")
case _: ShortType | _: IntegerType | _: LongType if failOnIntegralTypeOverflow =>
castIntegralTypeToIntegralTypeExactCode("byte")
case _: FloatType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("float", "byte")
case _: DoubleType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("double", "byte")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (byte) $c;"
}
Expand All @@ -1226,10 +1357,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (short) ${timestampToIntegerCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toShort();"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "short")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "short")
case _: IntegerType | _: LongType if failOnIntegralTypeOverflow =>
castIntegralTypeToIntegralTypeExactCode("short")
case _: FloatType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("float", "short")
case _: DoubleType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("double", "short")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (short) $c;"
}
Expand All @@ -1251,10 +1386,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (int) ${timestampToIntegerCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toInt();"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "int")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "int")
case _: LongType if failOnIntegralTypeOverflow => castIntegralTypeToIntegralTypeExactCode("int")
case _: FloatType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("float", "int")
case _: DoubleType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("double", "int")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (int) $c;"
}
Expand All @@ -1278,9 +1416,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (long) ${timestampToIntegerCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toLong();"
(c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};"
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "long")
case _: FloatType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("float", "long")
case _: DoubleType if failOnIntegralTypeOverflow =>
castFractionToIntegralTypeCode("double", "long")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
""")
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
private val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow
private val checkOverflow = SQLConf.get.failOnIntegralTypeOverflow

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

Expand Down Expand Up @@ -136,7 +136,7 @@ case class Abs(child: Expression)

abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

protected val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow
protected val checkOverflow = SQLConf.get.failOnIntegralTypeOverflow

override def dataType: DataType = left.dataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1805,9 +1805,9 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW =
buildConf("spark.sql.arithmeticOperations.failOnOverFlow")
.doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " +
val FAIL_ON_INTEGRAL_TYPE_OVERFLOW =
buildConf("spark.sql.failOnIntegralTypeOverflow")
.doc("If it is set to true, all operations on integral fields throw an " +
"exception if an overflow occurs. If it is false (default), in case of overflow a wrong " +
"result is returned.")
.internal()
Expand Down Expand Up @@ -2321,7 +2321,7 @@ class SQLConf extends Serializable with Logging {

def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)

def arithmeticOperationsFailOnOverflow: Boolean = getConf(ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW)
def failOnIntegralTypeOverflow: Boolean = getConf(FAIL_ON_INTEGRAL_TYPE_OVERFLOW)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,94 @@ final class Decimal extends Ordered[Decimal] with Serializable {

def toByte: Byte = toLong.toByte

private def overflowException(dataType: String) =
throw new ArithmeticException(s"Casting $this to $dataType causes overflow.")

/**
* @return the Byte value that is equal to the rounded decimal.
* @throws ArithmeticException if the decimal is too big to fit in Byte type.
*/
private[sql] def roundToByte(): Byte = {
if (decimalVal.eq(null)) {
val actualLongVal = longVal / POW_10(_scale)
if (actualLongVal == actualLongVal.toByte) {
actualLongVal.toByte
} else {
overflowException("byte")
}
} else {
val doubleVal = decimalVal.toDouble
if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= Byte.MinValue) {
doubleVal.toByte
} else {
overflowException("byte")
}
}
}

/**
* @return the Short value that is equal to the rounded decimal.
* @throws ArithmeticException if the decimal is too big to fit in Short type.
*/
private[sql] def roundToShort(): Short = {
if (decimalVal.eq(null)) {
val actualLongVal = longVal / POW_10(_scale)
if (actualLongVal == actualLongVal.toShort) {
actualLongVal.toShort
} else {
overflowException("short")
}
} else {
val doubleVal = decimalVal.toDouble
if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= Short.MinValue) {
doubleVal.toShort
} else {
overflowException("short")
}
}
}

/**
* @return the Int value that is equal to the rounded decimal.
* @throws ArithmeticException if the decimal too big to fit in Int type.
*/
private[sql] def roundToInt(): Int = {
if (decimalVal.eq(null)) {
val actualLongVal = longVal / POW_10(_scale)
if (actualLongVal == actualLongVal.toInt) {
actualLongVal.toInt
} else {
overflowException("int")
}
} else {
val doubleVal = decimalVal.toDouble
if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= Int.MinValue) {
doubleVal.toInt
} else {
overflowException("int")
}
}
}

/**
* @return the Long value that is equal to the rounded decimal.
* @throws ArithmeticException if the decimal too big to fit in Long type.
*/
private[sql] def roundToLong(): Long = {
if (decimalVal.eq(null)) {
longVal / POW_10(_scale)
} else {
try {
// We cannot store Long.MAX_VALUE as a Double without losing precision.
// Here we simply convert the decimal to `BigInteger` and use the method
// `longValueExact` to make sure the range check is accurate.
decimalVal.bigDecimal.toBigInteger.longValueExact()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also use the <= max && >= min check here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current implementation is simple and accurate.
If we convert the value to double, then it won't be accurate;
If we compare the value with another Decimal, then internally both values are converted to BigDecimal.

} catch {
case _: ArithmeticException => overflowException("long")
}
}
}

/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
private[sql] val ordering = Decimal.DecimalIsFractional
private[sql] val asIntegral = Decimal.DecimalAsIfIntegral

override private[sql] def exactNumeric = DecimalExactNumeric

override def typeName: String = s"decimal($precision,$scale)"

override def toString: String = s"DecimalType($precision,$scale)"
Expand Down
Loading