Skip to content

[SPARK-26218][SQL] Overflow on arithmetic operations returns incorrect result #21599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5c662f6
[SPARK-24598][SQL] Overflow on arithmetic operation returns incorrect…
mgaido91 Jun 20, 2018
fad75fa
fix scalastyle
mgaido91 Jun 20, 2018
8591417
fix ut failures
mgaido91 Jun 20, 2018
9c3df7d
use larger intermediate buffer for sum
mgaido91 Jun 21, 2018
ebdaf61
fix UT error
mgaido91 Jun 22, 2018
a0b862e
allow precision loss when converting decimal to long
mgaido91 Jun 22, 2018
7bba22f
Merge branch 'master' into SPARK-24598
mgaido91 Jul 16, 2018
77f26f2
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jun 21, 2019
74cd0a4
Handle NaN
mgaido91 Jun 22, 2019
2cfd946
Add conf flag for checking overflow
mgaido91 Jun 26, 2019
25c853c
fix
mgaido91 Jun 26, 2019
ff02dca
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jun 27, 2019
00fae1d
fix tests
mgaido91 Jun 27, 2019
8e9715c
change default value and fix tests
mgaido91 Jun 28, 2019
1dff779
Merge branch 'master' into SPARK-24598
mgaido91 Jul 14, 2019
38fc1f4
fix typo
mgaido91 Jul 15, 2019
0d5e510
Merge branch 'SPARK-24598' of github.com:mgaido91/spark into SPARK-24598
mgaido91 Jul 15, 2019
37e19ce
fix
mgaido91 Jul 15, 2019
eb37ee7
Merge branch 'master' of github.com:apache/spark into SPARK-24598
mgaido91 Jul 20, 2019
98bbf83
address comments
mgaido91 Jul 20, 2019
650ea79
fix
mgaido91 Jul 20, 2019
1d20f73
address comments
mgaido91 Jul 26, 2019
538e332
address comments
mgaido91 Jul 26, 2019
3de4bfb
fix
mgaido91 Jul 27, 2019
3baecbc
fixes
mgaido91 Jul 27, 2019
a247f9f
fix unaryminus
mgaido91 Jul 27, 2019
582d148
address comments
mgaido91 Jul 30, 2019
b809a3f
fix
mgaido91 Jul 30, 2019
ce3ed2b
address comments
mgaido91 Jul 31, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,36 @@ import org.apache.spark.unsafe.types.CalendarInterval
""")
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
private val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def dataType: DataType = child.dataType

override def toString: String = s"-$child"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case ByteType | ShortType if checkOverflow =>
nullSafeCodeGen(ctx, ev, eval => {
val javaBoxedType = CodeGenerator.boxedType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val originValue = ctx.freshName("origin")
s"""
|$javaType $originValue = ($javaType)($eval);
|if ($originValue == $javaBoxedType.MIN_VALUE) {
| throw new ArithmeticException("- " + $originValue + " caused overflow.");
|}
|${ev.value} = ($javaType)(-($originValue));
""".stripMargin
})
case IntegerType | LongType if checkOverflow =>
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
Expand Down Expand Up @@ -117,6 +136,8 @@ case class Abs(child: Expression)

abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

protected val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow

override def dataType: DataType = left.dataType

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
Expand All @@ -129,17 +150,57 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")

/** Name of the function for the exact version of this expression in [[Math]]. */
def exactMathMethod: String =
sys.error("BinaryArithmetics must override either exactMathMethod or genCode")

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType =>
// Overflow is handled in the CheckOverflow operator
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// byte and short are cast to int when performing add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val tmpResult = ctx.freshName("tmpResult")
val overflowCheck = if (checkOverflow) {
val javaType = CodeGenerator.boxedType(dataType)
s"""
|if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
| throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow.");
|}
""".stripMargin
} else {
""
}
s"""
|${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
|$overflowCheck
|${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
""".stripMargin
})
case IntegerType | LongType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val operation = if (checkOverflow) {
val mathClass = classOf[Math].getName
s"$mathClass.$exactMathMethod($eval1, $eval2)"
} else {
s"$eval1 $symbol $eval2"
}
s"""
|${ev.value} = $operation;
""".stripMargin
})
case DoubleType | FloatType =>
// When a Double/Float operation overflows, there are 2 possible cases:
// - precision loss: according to the SQL standard, the number is truncated;
// - the result is (+/-)Infinity: the same behavior other DBs have (e.g. Postgres).
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
|${ev.value} = $eval1 $symbol $eval2;
""".stripMargin
})
}
}

Expand All @@ -164,7 +225,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def calendarIntervalMethod: String = "add"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (dataType.isInstanceOf[CalendarIntervalType]) {
Expand All @@ -173,6 +234,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
numeric.plus(input1, input2)
}
}

override def exactMathMethod: String = "addExact"
}

@ExpressionDescription(
Expand All @@ -192,7 +255,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti

override def calendarIntervalMethod: String = "subtract"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
if (dataType.isInstanceOf[CalendarIntervalType]) {
Expand All @@ -201,6 +264,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
numeric.minus(input1, input2)
}
}

override def exactMathMethod: String = "subtractExact"
}

@ExpressionDescription(
Expand All @@ -217,9 +282,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "*"
override def decimalMethod: String = "$times"

private lazy val numeric = TypeUtils.getNumeric(dataType)
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)

override def exactMathMethod: String = "multiplyExact"
}

// Common base trait for Divide and Remainder, since these two classes are almost identical
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,13 @@ object TypeUtils {
}
}

def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
/**
 * Returns the [[Numeric]] instance attached to the numeric data type `t`, viewed as a
 * `Numeric[Any]`.
 *
 * @param t the data type; it is cast to [[NumericType]], so a non-numeric type fails with a
 *          `ClassCastException`
 * @param exactNumericRequired when true the type's `exactNumeric` is returned, otherwise its
 *                             plain `numeric` (the default)
 */
def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
  val numericType = t.asInstanceOf[NumericType]
  val selected = if (exactNumericRequired) numericType.exactNumeric else numericType.numeric
  selected.asInstanceOf[Numeric[Any]]
}

def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1780,6 +1780,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Internal boolean flag, false by default. Per its own doc text it decides whether arithmetic
// operations on non-decimal fields throw on overflow or return a wrong result. It is read
// through SQLConf.arithmeticOperationsFailOnOverflow.
val ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW =
buildConf("spark.sql.arithmeticOperations.failOnOverFlow")
.doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " +
"exception if an overflow occurs. If it is false (default), in case of overflow a wrong " +
"result is returned.")
.internal()
.booleanConf
.createWithDefault(false)

val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
.internal()
Expand Down Expand Up @@ -2287,6 +2296,8 @@ class SQLConf extends Serializable with Logging {

def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)

def arithmeticOperationsFailOnOverflow: Boolean = getConf(ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ abstract class NumericType extends AtomicType {
// desugared by the compiler into an argument to the objects constructor. This means there is no
// longer a no argument constructor and thus the JVM cannot serialize the object anymore.
private[sql] val numeric: Numeric[InternalType]

// Numeric handed out when an exact (overflow-checking) numeric is requested. It defaults to the
// regular `numeric`; concrete types may override it with a checking implementation.
private[sql] def exactNumeric: Numeric[InternalType] = numeric
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ByteType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = ByteExactNumeric

/**
* The default size of a value of the ByteType is 1 byte.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class IntegerType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = IntegerExactNumeric

/**
* The default size of a value of the IntegerType is 4 bytes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class LongType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = LongExactNumeric

/**
* The default size of a value of the LongType is 8 bytes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ShortType private() extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[InternalType]]
override private[sql] val exactNumeric = ShortExactNumeric

/**
* The default size of a value of the ShortType is 2 bytes.
Expand Down
110 changes: 110 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types

import scala.math.Numeric.{ByteIsIntegral, IntIsIntegral, LongIsIntegral, ShortIsIntegral}
import scala.math.Ordering


/**
 * [[Numeric]] for `Byte` whose `plus`, `minus`, `times` and `negate` throw an
 * [[ArithmeticException]] when the result does not fit in a `Byte`, instead of wrapping around.
 */
object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
  // The operands are widened to Int before the operation, so `wide` holds the exact result and
  // only has to be range-checked before it is narrowed back to Byte.
  private def narrow(wide: Int, lhs: Byte, rhs: Byte, symbol: String): Byte = {
    val fits = wide >= Byte.MinValue && wide <= Byte.MaxValue
    if (!fits) {
      throw new ArithmeticException(s"$lhs $symbol $rhs caused overflow.")
    }
    wide.toByte
  }

  override def plus(x: Byte, y: Byte): Byte = narrow(x + y, x, y, "+")

  override def minus(x: Byte, y: Byte): Byte = narrow(x - y, x, y, "-")

  override def times(x: Byte, y: Byte): Byte = narrow(x * y, x, y, "*")

  override def negate(x: Byte): Byte = {
    // Byte.MinValue is the only value whose negation is not representable as a Byte.
    if (x != Byte.MinValue) {
      (-x).toByte
    } else {
      throw new ArithmeticException(s"- $x caused overflow.")
    }
  }
}


/**
 * [[Numeric]] for `Short` whose `plus`, `minus`, `times` and `negate` throw an
 * [[ArithmeticException]] when the result does not fit in a `Short`, instead of wrapping around.
 */
object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering {
  // The operands are widened to Int before the operation, so `wide` holds the exact result and
  // only has to be range-checked before it is narrowed back to Short.
  private def narrow(wide: Int, lhs: Short, rhs: Short, symbol: String): Short = {
    val fits = wide >= Short.MinValue && wide <= Short.MaxValue
    if (!fits) {
      throw new ArithmeticException(s"$lhs $symbol $rhs caused overflow.")
    }
    wide.toShort
  }

  override def plus(x: Short, y: Short): Short = narrow(x + y, x, y, "+")

  override def minus(x: Short, y: Short): Short = narrow(x - y, x, y, "-")

  override def times(x: Short, y: Short): Short = narrow(x * y, x, y, "*")

  override def negate(x: Short): Short = {
    // Short.MinValue is the only value whose negation is not representable as a Short.
    if (x != Short.MinValue) {
      (-x).toShort
    } else {
      throw new ArithmeticException(s"- $x caused overflow.")
    }
  }
}


/**
 * [[Numeric]] for `Int` that delegates `plus`, `minus`, `times` and `negate` to the `*Exact`
 * methods of `java.lang.Math`, so an overflow raises an [[ArithmeticException]].
 */
object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering {
  import java.lang.Math.{addExact, multiplyExact, negateExact, subtractExact}

  override def plus(x: Int, y: Int): Int = {
    addExact(x, y)
  }

  override def minus(x: Int, y: Int): Int = {
    subtractExact(x, y)
  }

  override def times(x: Int, y: Int): Int = {
    multiplyExact(x, y)
  }

  override def negate(x: Int): Int = {
    negateExact(x)
  }
}

/**
 * [[Numeric]] for `Long` that delegates `plus`, `minus`, `times` and `negate` to the `*Exact`
 * methods of `java.lang.Math`, so an overflow raises an [[ArithmeticException]].
 */
object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering {
  import java.lang.Math.{addExact, multiplyExact, negateExact, subtractExact}

  override def plus(x: Long, y: Long): Long = {
    addExact(x, y)
  }

  override def minus(x: Long, y: Long): Long = {
    subtractExact(x, y)
  }

  override def times(x: Long, y: Long): Long = {
    multiplyExact(x, y)
  }

  override def negate(x: Long): Long = {
    negateExact(x)
  }
}
Loading