Skip to content

Commit f100d88

Browse files
committed
use old overflow style
1 parent 7293377 commit f100d88

File tree

7 files changed

+229
-138
lines changed

7 files changed

+229
-138
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
3737
with ExpectsInputTypes with NullIntolerant {
3838
private val checkOverflow = SQLConf.get.ansiEnabled
3939

40-
override def nullable: Boolean = dataType match {
41-
case CalendarIntervalType => true
42-
case _ => child.nullable
43-
}
44-
4540
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
4641

4742
override def dataType: DataType = child.dataType
@@ -80,28 +75,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
8075
"""})
8176
case _: CalendarIntervalType =>
8277
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
83-
nullSafeCodeGen(ctx, ev, interval => s"""
84-
try {
85-
${ev.value} = $iu.negate($interval);
86-
} catch (ArithmeticException e) {
87-
if ($checkOverflow) {
88-
throw new ArithmeticException("-($interval) caused interval overflow.");
89-
} else {
90-
${ev.isNull} = true;
91-
}
78+
defineCodeGen(ctx, ev,
79+
interval => if (checkOverflow) {
80+
s"$iu.negate($interval)"
81+
} else {
82+
s"$iu.safeNegate($interval)"
9283
}
93-
""")
84+
)
9485
}
9586

9687
protected override def nullSafeEval(input: Any): Any = dataType match {
97-
case CalendarIntervalType =>
98-
try {
88+
case CalendarIntervalType if checkOverflow =>
9989
IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
100-
} catch {
101-
case _: ArithmeticException if checkOverflow =>
102-
throw new ArithmeticException(s"$sql caused interval overflow")
103-
case _: ArithmeticException => null
104-
}
90+
case CalendarIntervalType => IntervalUtils.safeNegate(input.asInstanceOf[CalendarInterval])
10591
case _ => numeric.negate(input)
10692
}
10793

@@ -182,19 +168,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
182168
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
183169
case CalendarIntervalType =>
184170
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
185-
nullSafeCodeGen(ctx, ev, (eval1, eval2) =>
186-
s"""
187-
|try {
188-
| ${ev.value} = $iu.$calendarIntervalMethod($eval1, $eval2);
189-
|} catch (ArithmeticException e) {
190-
| if ($checkOverflow) {
191-
| throw new ArithmeticException(
192-
| "$eval1 $calendarIntervalMethod $eval2 caused interval overflow.");
193-
| } else {
194-
| ${ev.isNull} = true;
195-
| }
196-
|}
197-
|""".stripMargin)
171+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
198172
// byte and short are casted into int when add, minus, times or divide
199173
case ByteType | ShortType =>
200174
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
@@ -252,31 +226,23 @@ object BinaryArithmetic {
252226
""")
253227
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
254228

255-
override def nullable: Boolean = dataType match {
256-
case CalendarIntervalType => true
257-
case _ => super.nullable
258-
}
259-
260229
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
261230

262231
override def symbol: String = "+"
263232

264233
override def decimalMethod: String = "$plus"
265234

266-
override def calendarIntervalMethod: String = "add"
235+
override def calendarIntervalMethod: String = if (checkOverflow) "add" else "safeAdd"
267236

268237
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
269238

270239
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
240+
case CalendarIntervalType if checkOverflow =>
241+
IntervalUtils.add(
242+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
271243
case CalendarIntervalType =>
272-
try {
273-
IntervalUtils.add(
274-
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
275-
} catch {
276-
case _: ArithmeticException if checkOverflow =>
277-
throw new ArithmeticException(s"$sql causes interval overflow")
278-
case _: ArithmeticException => null
279-
}
244+
IntervalUtils.safeAdd(
245+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
280246
case _ => numeric.plus(input1, input2)
281247
}
282248

@@ -292,31 +258,23 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
292258
""")
293259
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
294260

295-
override def nullable: Boolean = dataType match {
296-
case CalendarIntervalType => true
297-
case _ => super.nullable
298-
}
299-
300261
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
301262

302263
override def symbol: String = "-"
303264

304265
override def decimalMethod: String = "$minus"
305266

306-
override def calendarIntervalMethod: String = "subtract"
267+
override def calendarIntervalMethod: String = if (checkOverflow) "subtract" else "safeSubtract"
307268

308269
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
309270

310271
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
272+
case CalendarIntervalType if checkOverflow =>
273+
IntervalUtils.subtract(
274+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
311275
case CalendarIntervalType =>
312-
try {
313-
IntervalUtils.subtract(
314-
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
315-
} catch {
316-
case _: ArithmeticException if checkOverflow =>
317-
throw new ArithmeticException(s"$sql caused interval overflow")
318-
case _: ArithmeticException => null
319-
}
276+
IntervalUtils.safeSubtract(
277+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
320278
case _ => numeric.minus(input1, input2)
321279
}
322280

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,14 @@ object ExtractIntervalPart {
113113
}
114114
}
115115

116-
abstract class IntervalNumOperation(
117-
interval: Expression,
118-
num: Expression,
119-
operation: (CalendarInterval, Double) => CalendarInterval,
120-
operationName: String)
116+
abstract class IntervalNumOperation(interval: Expression, num: Expression)
121117
extends BinaryExpression with ImplicitCastInputTypes with Serializable {
122-
private val checkOverflow = SQLConf.get.ansiEnabled
118+
119+
protected val checkOverflow: Boolean = SQLConf.get.ansiEnabled
120+
121+
protected def operation(interval: CalendarInterval, num: Double): CalendarInterval
122+
123+
protected val operationName: String
123124

124125
override def left: Expression = interval
125126
override def right: Expression = num
@@ -133,9 +134,7 @@ abstract class IntervalNumOperation(
133134
try {
134135
operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
135136
} catch {
136-
case _: ArithmeticException if checkOverflow =>
137-
throw new ArithmeticException(s"$sql caused interval overflow.")
138-
case _: ArithmeticException => null
137+
case _: ArithmeticException if (!checkOverflow) => null
139138
}
140139
}
141140

@@ -147,7 +146,7 @@ abstract class IntervalNumOperation(
147146
${ev.value} = $iu.$operationName($interval, $num);
148147
} catch (ArithmeticException e) {
149148
if ($checkOverflow) {
150-
throw new ArithmeticException("$prettyName($interval, $num) caused interval overflow.");
149+
throw e;
151150
} else {
152151
${ev.isNull} = true;
153152
}
@@ -156,14 +155,28 @@ abstract class IntervalNumOperation(
156155
})
157156
}
158157

159-
override def prettyName: String = operationName + "_interval"
158+
override def prettyName: String = operationName.stripPrefix("safe").toLowerCase() + "_interval"
160159
}
161160

162161
case class MultiplyInterval(interval: Expression, num: Expression)
163-
extends IntervalNumOperation(interval, num, multiply, "multiply")
162+
extends IntervalNumOperation(interval, num) {
163+
164+
override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = {
165+
if (checkOverflow) multiply(interval, num) else safeMultiply(interval, num)
166+
}
167+
168+
override protected val operationName: String = if (checkOverflow) "multiply" else "safeMultiply"
169+
}
164170

165171
case class DivideInterval(interval: Expression, num: Expression)
166-
extends IntervalNumOperation(interval, num, divide, "divide")
172+
extends IntervalNumOperation(interval, num) {
173+
174+
override protected def operation(interval: CalendarInterval, num: Double): CalendarInterval = {
175+
if (checkOverflow) divide(interval, num) else safeDivide(interval, num)
176+
}
177+
178+
override protected val operationName: String = if (checkOverflow) "divide" else "safeDivide"
179+
}
167180

168181
// scalastyle:off line.size.limit
169182
@ExpressionDescription(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ object IntervalUtils {
401401
/**
402402
* Makes an interval from months, days and micros with the fractional part by
403403
* adding the month fraction to days and the days fraction to micros.
404+
*
405+
* @throws ArithmeticException if the result overflows any field value
404406
*/
405407
private def fromDoubles(
406408
monthsWithFraction: Double,
@@ -413,11 +415,41 @@ object IntervalUtils {
413415
new CalendarInterval(truncatedMonths, truncatedDays, micros.round)
414416
}
415417

418+
/**
419+
* Makes an interval from months, days and micros with the fractional part by
420+
* adding the month fraction to days and the days fraction to micros.
421+
*/
422+
private def safeFromDoubles(
423+
monthsWithFraction: Double,
424+
daysWithFraction: Double,
425+
microsWithFraction: Double): CalendarInterval = {
426+
val monthInLong = monthsWithFraction.toLong
427+
val truncatedMonths = if (monthInLong > Int.MaxValue) {
428+
Int.MaxValue
429+
} else if (monthInLong < Int.MinValue) {
430+
Int.MinValue
431+
} else {
432+
monthInLong.toInt
433+
}
434+
val days = daysWithFraction + DAYS_PER_MONTH * (monthsWithFraction - truncatedMonths)
435+
val dayInLong = days.toLong
436+
val truncatedDays = if (dayInLong > Int.MaxValue) {
437+
Int.MaxValue
438+
} else if (monthInLong < Int.MinValue) {
439+
Int.MinValue
440+
} else {
441+
dayInLong.toInt
442+
}
443+
val micros = microsWithFraction + MICROS_PER_DAY * (days - truncatedDays)
444+
new CalendarInterval(truncatedMonths, truncatedDays.toInt, micros.round)
445+
}
446+
416447
/**
417448
* Unary minus, return the negated the calendar interval value.
418449
*
419450
* @param interval the interval to be negated
420451
* @return a new calendar interval instance with all it parameters negated from the origin one.
452+
* @throws ArithmeticException if the result overflows any field value
421453
*/
422454
def negate(interval: CalendarInterval): CalendarInterval = {
423455
val months = Math.negateExact(interval.months)
@@ -426,8 +458,21 @@ object IntervalUtils {
426458
new CalendarInterval(months, days, microseconds)
427459
}
428460

461+
/**
462+
* Unary minus, return the negated the calendar interval value.
463+
*
464+
* @param interval the interval to be negated
465+
* @return a new calendar interval instance with all it parameters negated from the origin one.
466+
*/
467+
def safeNegate(interval: CalendarInterval): CalendarInterval = {
468+
new CalendarInterval(-interval.months, -interval.days, -interval.microseconds)
469+
}
470+
429471
/**
430472
* Return a new calendar interval instance of the sum of two intervals.
473+
*
474+
* @throws ArithmeticException if the result overflows any field value
475+
*
431476
*/
432477
def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
433478
val months = Math.addExact(left.months, right.months)
@@ -437,7 +482,20 @@ object IntervalUtils {
437482
}
438483

439484
/**
440-
* Return a new calendar interval instance of the left intervals minus the right one.
485+
* Return a new calendar interval instance of the sum of two intervals.
486+
*/
487+
def safeAdd(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
488+
val months = left.months + right.months
489+
val days = left.days + right.days
490+
val microseconds = left.microseconds + right.microseconds
491+
new CalendarInterval(months, days, microseconds)
492+
}
493+
494+
/**
495+
* Return a new calendar interval instance of the left interval minus the right one.
496+
*
497+
* @throws ArithmeticException if the result overflows any field value
498+
*
441499
*/
442500
def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
443501
val months = Math.subtractExact(left.months, right.months)
@@ -446,15 +504,52 @@ object IntervalUtils {
446504
new CalendarInterval(months, days, microseconds)
447505
}
448506

507+
/**
508+
* Return a new calendar interval instance of the left interval minus the right one.
509+
*/
510+
def safeSubtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
511+
val months = left.months - right.months
512+
val days = left.days - right.days
513+
val microseconds = left.microseconds - right.microseconds
514+
new CalendarInterval(months, days, microseconds)
515+
}
516+
517+
/**
518+
* Return a new calendar interval instance of the left interval times a multiplier.
519+
*
520+
* @throws ArithmeticException if the result overflows any field value
521+
*/
449522
def multiply(interval: CalendarInterval, num: Double): CalendarInterval = {
450523
fromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
451524
}
452525

526+
/**
527+
* Return a new calendar interval instance of the left interval times a multiplier.
528+
*/
529+
def safeMultiply(interval: CalendarInterval, num: Double): CalendarInterval = {
530+
safeFromDoubles(num * interval.months, num * interval.days, num * interval.microseconds)
531+
}
532+
533+
/**
534+
* Return a new calendar interval instance of the left interval divides by a dividend.
535+
*
536+
* @throws ArithmeticException if the result overflows any field value or divided by zero
537+
*/
453538
def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
454539
if (num == 0) throw new ArithmeticException("divide by zero")
455540
fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
456541
}
457542

543+
/**
544+
* Return a new calendar interval instance of the left interval divides by a dividend.
545+
*
546+
* @throws ArithmeticException if divided by zero
547+
*/
548+
def safeDivide(interval: CalendarInterval, num: Double): CalendarInterval = {
549+
if (num == 0) throw new ArithmeticException("divide by zero")
550+
safeFromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
551+
}
552+
458553
// `toString` implementation in CalendarInterval is the multi-units format currently.
459554
def toMultiUnitsString(interval: CalendarInterval): String = interval.toString
460555

0 commit comments

Comments
 (0)