Skip to content

Commit e04309c

Browse files
yaooqinn and cloud-fan
authored and committed
[SPARK-30341][SQL] Overflow check for interval arithmetic operations
### What changes were proposed in this pull request?

1. For the interval arithmetic functions, e.g. `add`/`subtract`/`negative`/`multiply`/`divide`, enable the overflow check when `ANSI` is on.
2. For `multiply`/`divide`, throw an exception when an overflow happens, regardless of whether `ANSI` is on or off.
3. `add`/`subtract`/`negative` stay the same when `ANSI` is off, for backward compatibility.
4. `divide` by 0 throws an ArithmeticException whether `ANSI` is on or not, the same as for numerics.
5. When `ANSI` is on, these behaviors are fully consistent with the numeric type operations.
6. When `ANSI` is off, these behaviors are consistent with the numeric type operations, except for items 2 and 4.

### Why are the changes needed?

1. Bug fix.
2. `ANSI` support.

### Does this PR introduce any user-facing change?

When `ANSI` is on, interval `add`/`subtract`/`negative`/`multiply`/`divide` will throw an ArithmeticException if any field overflows.

### How was this patch tested?

Added unit tests.

Closes #26995 from yaooqinn/SPARK-30341.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 68260f5 commit e04309c

File tree

10 files changed

+310
-112
lines changed

10 files changed

+310
-112
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20-
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
20+
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry}
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.util.TypeUtils
2423
import org.apache.spark.sql.types._
2524

2625
@ExpressionDescription(
@@ -81,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
8180
case _: DecimalType =>
8281
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
8382
case CalendarIntervalType =>
84-
DivideInterval(sum.cast(resultType), count.cast(DoubleType))
83+
val newCount = If(EqualTo(count, Literal(0L)), Literal(null, LongType), count)
84+
DivideInterval(sum.cast(resultType), newCount.cast(DoubleType))
8585
case _ =>
8686
sum.cast(resultType) / count.cast(resultType)
8787
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
7575
"""})
7676
case _: CalendarIntervalType =>
7777
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
78-
defineCodeGen(ctx, ev, c => s"$iu.negate($c)")
78+
val method = if (checkOverflow) "negateExact" else "negate"
79+
defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
7980
}
8081

8182
protected override def nullSafeEval(input: Any): Any = dataType match {
83+
case CalendarIntervalType if checkOverflow =>
84+
IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval])
8285
case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
83-
case _ => numeric.negate(input)
86+
case _ => numeric.negate(input)
8487
}
8588

8689
override def sql: String = s"(- ${child.sql})"
@@ -224,13 +227,17 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
224227

225228
override def decimalMethod: String = "$plus"
226229

227-
override def calendarIntervalMethod: String = "add"
230+
override def calendarIntervalMethod: String = if (checkOverflow) "addExact" else "add"
228231

229232
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
230233

231234
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
232-
case CalendarIntervalType => IntervalUtils.add(
233-
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
235+
case CalendarIntervalType if checkOverflow =>
236+
IntervalUtils.addExact(
237+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
238+
case CalendarIntervalType =>
239+
IntervalUtils.add(
240+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
234241
case _ => numeric.plus(input1, input2)
235242
}
236243

@@ -252,13 +259,17 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
252259

253260
override def decimalMethod: String = "$minus"
254261

255-
override def calendarIntervalMethod: String = "subtract"
262+
override def calendarIntervalMethod: String = if (checkOverflow) "subtractExact" else "subtract"
256263

257264
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
258265

259266
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
260-
case CalendarIntervalType => IntervalUtils.subtract(
261-
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
267+
case CalendarIntervalType if checkOverflow =>
268+
IntervalUtils.subtractExact(
269+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
270+
case CalendarIntervalType =>
271+
IntervalUtils.subtract(
272+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
262273
case _ => numeric.minus(input1, input2)
263274
}
264275

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,34 +125,22 @@ abstract class IntervalNumOperation(
125125
override def nullable: Boolean = true
126126

127127
override def nullSafeEval(interval: Any, num: Any): Any = {
128-
try {
129-
operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
130-
} catch {
131-
case _: java.lang.ArithmeticException => null
132-
}
128+
operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
133129
}
134130

135131
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
136-
nullSafeCodeGen(ctx, ev, (interval, num) => {
137-
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
138-
s"""
139-
try {
140-
${ev.value} = $iu.$operationName($interval, $num);
141-
} catch (java.lang.ArithmeticException e) {
142-
${ev.isNull} = true;
143-
}
144-
"""
145-
})
132+
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
133+
defineCodeGen(ctx, ev, (interval, num) => s"$iu.$operationName($interval, $num)")
146134
}
147135

148-
override def prettyName: String = operationName + "_interval"
136+
override def prettyName: String = operationName.stripSuffix("Exact") + "_interval"
149137
}
150138

151139
case class MultiplyInterval(interval: Expression, num: Expression)
152-
extends IntervalNumOperation(interval, num, multiply, "multiply")
140+
extends IntervalNumOperation(interval, num, multiplyExact, "multiplyExact")
153141

154142
case class DivideInterval(interval: Expression, num: Expression)
155-
extends IntervalNumOperation(interval, num, divide, "divide")
143+
extends IntervalNumOperation(interval, num, divideExact, "divideExact")
156144

157145
// scalastyle:off line.size.limit
158146
@ExpressionDescription(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ object IntervalUtils {
138138
assert(input.length == input.trim.length)
139139
input match {
140140
case yearMonthPattern("-", yearStr, monthStr) =>
141-
negate(toInterval(yearStr, monthStr))
141+
negateExact(toInterval(yearStr, monthStr))
142142
case yearMonthPattern(_, yearStr, monthStr) =>
143143
toInterval(yearStr, monthStr)
144144
case _ =>
@@ -401,6 +401,8 @@ object IntervalUtils {
401401
/**
402402
* Makes an interval from months, days and micros with the fractional part by
403403
* adding the month fraction to days and the days fraction to micros.
404+
*
405+
* @throws ArithmeticException if the result overflows any field value
404406
*/
405407
private def fromDoubles(
406408
monthsWithFraction: Double,
@@ -416,13 +418,34 @@ object IntervalUtils {
416418
/**
417419
* Unary minus, return the negated calendar interval value.
418420
*
419-
* @param interval the interval to be negated
420-
* @return a new calendar interval instance with all it parameters negated from the origin one.
421+
* @throws ArithmeticException if the result overflows any field value
422+
*/
423+
def negateExact(interval: CalendarInterval): CalendarInterval = {
424+
val months = Math.negateExact(interval.months)
425+
val days = Math.negateExact(interval.days)
426+
val microseconds = Math.negateExact(interval.microseconds)
427+
new CalendarInterval(months, days, microseconds)
428+
}
429+
430+
/**
431+
* Unary minus, return the negated calendar interval value.
421432
*/
422433
def negate(interval: CalendarInterval): CalendarInterval = {
423434
new CalendarInterval(-interval.months, -interval.days, -interval.microseconds)
424435
}
425436

437+
/**
438+
* Return a new calendar interval instance of the sum of two intervals.
439+
*
440+
* @throws ArithmeticException if the result overflows any field value
441+
*/
442+
def addExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
443+
val months = Math.addExact(left.months, right.months)
444+
val days = Math.addExact(left.days, right.days)
445+
val microseconds = Math.addExact(left.microseconds, right.microseconds)
446+
new CalendarInterval(months, days, microseconds)
447+
}
448+
426449
/**
427450
* Return a new calendar interval instance of the sum of two intervals.
428451
*/
@@ -434,7 +457,19 @@ object IntervalUtils {
434457
}
435458

436459
/**
437-
* Return a new calendar interval instance of the left intervals minus the right one.
460+
* Return a new calendar interval instance of the left interval minus the right one.
461+
*
462+
* @throws ArithmeticException if the result overflows any field value
463+
*/
464+
def subtractExact(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
465+
val months = Math.subtractExact(left.months, right.months)
466+
val days = Math.subtractExact(left.days, right.days)
467+
val microseconds = Math.subtractExact(left.microseconds, right.microseconds)
468+
new CalendarInterval(months, days, microseconds)
469+
}
470+
471+
/**
472+
* Return a new calendar interval instance of the left interval minus the right one.
438473
*/
439474
def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
440475
val months = left.months - right.months
@@ -443,12 +478,22 @@ object IntervalUtils {
443478
new CalendarInterval(months, days, microseconds)
444479
}
445480

446-
def multiply(interval: CalendarInterval, num: Double): CalendarInterval = {
481+
/**
482+
* Return a new calendar interval instance of the left interval times a multiplier.
483+
*
484+
* @throws ArithmeticException if the result overflows any field value
485+
*/
486+
def multiplyExact(interval: CalendarInterval, num: Double): CalendarInterval = {
447487
fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
448488
}
449489

450-
def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
451-
if (num == 0) throw new java.lang.ArithmeticException("divide by zero")
490+
/**
491+
* Return a new calendar interval instance of the left interval divided by a divisor.
492+
*
493+
* @throws ArithmeticException if the result overflows any field value or if the divisor is zero
494+
*/
495+
def divideExact(interval: CalendarInterval, num: Double): CalendarInterval = {
496+
if (num == 0) throw new ArithmeticException("divide by zero")
452497
fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
453498
}
454499

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
733733
checkEvaluation(new Sequence(
734734
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
735735
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
736-
Literal(negate(stringToInterval("interval 12 hours")))),
736+
Literal(negateExact(stringToInterval("interval 12 hours")))),
737737
Seq(
738738
Timestamp.valueOf("2018-01-02 00:00:00"),
739739
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -742,7 +742,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
742742
checkEvaluation(new Sequence(
743743
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
744744
Literal(Timestamp.valueOf("2017-12-31 23:59:59")),
745-
Literal(negate(stringToInterval("interval 12 hours")))),
745+
Literal(negateExact(stringToInterval("interval 12 hours")))),
746746
Seq(
747747
Timestamp.valueOf("2018-01-02 00:00:00"),
748748
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -760,7 +760,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
760760
checkEvaluation(new Sequence(
761761
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
762762
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
763-
Literal(negate(stringToInterval("interval 1 month")))),
763+
Literal(negateExact(stringToInterval("interval 1 month")))),
764764
Seq(
765765
Timestamp.valueOf("2018-03-01 00:00:00"),
766766
Timestamp.valueOf("2018-02-01 00:00:00"),
@@ -769,7 +769,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
769769
checkEvaluation(new Sequence(
770770
Literal(Timestamp.valueOf("2018-03-03 00:00:00")),
771771
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
772-
Literal(negate(stringToInterval("interval 1 month 1 day")))),
772+
Literal(negateExact(stringToInterval("interval 1 month 1 day")))),
773773
Seq(
774774
Timestamp.valueOf("2018-03-03 00:00:00"),
775775
Timestamp.valueOf("2018-02-02 00:00:00"),
@@ -815,7 +815,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
815815
checkEvaluation(new Sequence(
816816
Literal(Timestamp.valueOf("2022-04-01 00:00:00")),
817817
Literal(Timestamp.valueOf("2017-01-01 00:00:00")),
818-
Literal(negate(fromYearMonthString("1-5")))),
818+
Literal(negateExact(fromYearMonthString("1-5")))),
819819
Seq(
820820
Timestamp.valueOf("2022-04-01 00:00:00.000"),
821821
Timestamp.valueOf("2020-11-01 00:00:00.000"),
@@ -907,7 +907,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
907907
new Sequence(
908908
Literal(Date.valueOf("1970-01-01")),
909909
Literal(Date.valueOf("1970-02-01")),
910-
Literal(negate(stringToInterval("interval 1 month")))),
910+
Literal(negateExact(stringToInterval("interval 1 month")))),
911911
EmptyRow,
912912
s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}")
913913
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import scala.language.implicitConversions
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
24-
import org.apache.spark.sql.catalyst.util.IntervalUtils.stringToInterval
24+
import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval}
25+
import org.apache.spark.sql.internal.SQLConf
2526
import org.apache.spark.sql.types.Decimal
2627
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2728

@@ -198,9 +199,17 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
198199

199200
test("multiply") {
200201
def check(interval: String, num: Double, expected: String): Unit = {
201-
checkEvaluation(
202-
MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)),
203-
if (expected == null) null else stringToInterval(expected))
202+
val expr = MultiplyInterval(Literal(stringToInterval(interval)), Literal(num))
203+
val expectedRes = safeStringToInterval(expected)
204+
Seq("true", "false").foreach { v =>
205+
withSQLConf(SQLConf.ANSI_ENABLED.key -> v) {
206+
if (expectedRes == null) {
207+
checkExceptionInExpression[ArithmeticException](expr, expected)
208+
} else {
209+
checkEvaluation(expr, expectedRes)
210+
}
211+
}
212+
}
204213
}
205214

206215
check("0 seconds", 10, "0 seconds")
@@ -211,14 +220,22 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
211220
check("-100 years -1 millisecond", 0.5, "-50 years -500 microseconds")
212221
check("2 months 4 seconds", -0.5, "-1 months -2 seconds")
213222
check("1 month 2 microseconds", 1.5, "1 months 15 days 3 microseconds")
214-
check("2 months", Int.MaxValue, null)
223+
check("2 months", Int.MaxValue, "integer overflow")
215224
}
216225

217226
test("divide") {
218227
def check(interval: String, num: Double, expected: String): Unit = {
219-
checkEvaluation(
220-
DivideInterval(Literal(stringToInterval(interval)), Literal(num)),
221-
if (expected == null) null else stringToInterval(expected))
228+
val expr = DivideInterval(Literal(stringToInterval(interval)), Literal(num))
229+
val expectedRes = safeStringToInterval(expected)
230+
Seq("true", "false").foreach { v =>
231+
withSQLConf(SQLConf.ANSI_ENABLED.key -> v) {
232+
if (expectedRes == null) {
233+
checkExceptionInExpression[ArithmeticException](expr, expected)
234+
} else {
235+
checkEvaluation(expr, expectedRes)
236+
}
237+
}
238+
}
222239
}
223240

224241
check("0 seconds", 10, "0 seconds")
@@ -228,7 +245,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
228245
check("2 years -8 seconds", 0.5, "4 years -16 seconds")
229246
check("-1 month 2 microseconds", -0.25, "4 months -8 microseconds")
230247
check("1 month 3 microsecond", 1.5, "20 days 2 microseconds")
231-
check("1 second", 0, null)
248+
check("1 second", 0, "divide by zero")
249+
check(s"${Int.MaxValue} months", 0.9, "integer overflow")
232250
}
233251

234252
test("make interval") {

0 commit comments

Comments
 (0)