Skip to content

Commit 79f5892

Browse files
committed
add exception for string to interval
1 parent ed3b35f commit 79f5892

File tree

16 files changed

+118
-117
lines changed

16 files changed

+118
-117
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
467467
// IntervalConverter
468468
private[this] def castToInterval(from: DataType): Any => Any = from match {
469469
case StringType =>
470-
buildCast[UTF8String](_, s => IntervalUtils.stringToInterval(s))
470+
buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
471471
}
472472

473473
// LongConverter

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2525
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
2626
import org.apache.spark.sql.catalyst.util.IntervalUtils
2727
import org.apache.spark.sql.types._
28+
import org.apache.spark.unsafe.types.UTF8String
2829

2930
case class TimeWindow(
3031
timeColumn: Expression,
@@ -103,7 +104,7 @@ object TimeWindow {
103104
* precision.
104105
*/
105106
private def getIntervalInMicroSeconds(interval: String): Long = {
106-
val cal = IntervalUtils.fromString(interval)
107+
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
107108
if (cal.months != 0) {
108109
throw new IllegalArgumentException(
109110
s"Intervals greater than a month is not supported ($interval).")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,7 +1872,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
18721872
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
18731873
case "INTERVAL" =>
18741874
val interval = try {
1875-
IntervalUtils.fromString(value)
1875+
IntervalUtils.stringToInterval(UTF8String.fromString(value))
18761876
} catch {
18771877
case e: IllegalArgumentException =>
18781878
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
@@ -2082,10 +2082,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
20822082
} else {
20832083
values(i).getText
20842084
}
2085-
v + " " + u
2085+
UTF8String.fromString(" " + v + " " + u)
20862086
}
2087-
val str = kvs.mkString(" ")
2088-
IntervalUtils.fromString(str)
2087+
IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
20892088
} catch {
20902089
case i: IllegalArgumentException =>
20912090
val e = new ParseException(i.getMessage, ctx)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,6 @@ object IntervalUtils {
100100
Decimal(result, 18, 6)
101101
}
102102

103-
/**
104-
* Converts a string to [[CalendarInterval]] case-insensitively.
105-
*
106-
* @throws IllegalArgumentException if the input string is not in valid interval format.
107-
*/
108-
def fromString(str: String): CalendarInterval = {
109-
if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
110-
val interval = stringToInterval(UTF8String.fromString(str))
111-
if (interval == null) {
112-
throw new IllegalArgumentException(s"Invalid interval string: $str")
113-
}
114-
interval
115-
}
116-
117103
private def toLongWithRange(
118104
fieldName: IntervalUnit,
119105
s: String,
@@ -250,30 +236,6 @@ object IntervalUtils {
250236
}
251237
}
252238

253-
/**
254-
* Parse second_nano string in ss.nnnnnnnnn format to microseconds
255-
*/
256-
private def parseSecondNano(secondNano: String): Long = {
257-
def parseSeconds(secondsStr: String): Long = {
258-
toLongWithRange(
259-
SECOND,
260-
secondsStr,
261-
Long.MinValue / MICROS_PER_SECOND,
262-
Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
263-
}
264-
265-
secondNano.split("\\.") match {
266-
case Array(secondsStr) => parseSeconds(secondsStr)
267-
case Array("", nanosStr) => parseNanos(nanosStr, false)
268-
case Array(secondsStr, nanosStr) =>
269-
val seconds = parseSeconds(secondsStr)
270-
Math.addExact(seconds, parseNanos(nanosStr, seconds < 0))
271-
case _ =>
272-
throw new IllegalArgumentException(
273-
"Interval string does not match second-nano format of ss.nnnnnnnnn")
274-
}
275-
}
276-
277239
/**
278240
* Gets interval duration
279241
*
@@ -397,20 +359,40 @@ object IntervalUtils {
397359
private final val millisStr = unitToUtf8(MILLISECOND)
398360
private final val microsStr = unitToUtf8(MICROSECOND)
399361

362+
/**
363+
* A safe version of `stringToInterval`: returns null instead of throwing an IllegalArgumentException when the input string cannot be parsed.
364+
*/
365+
def safeStringToInterval(input: UTF8String): CalendarInterval = {
366+
try {
367+
stringToInterval(input)
368+
} catch {
369+
case _: IllegalArgumentException => null
370+
}
371+
}
372+
373+
/**
374+
* Converts a string to [[CalendarInterval]] case-insensitively.
375+
*
376+
* @throws IllegalArgumentException if the input string is null, empty, or not in a valid interval format.
377+
*/
400378
def stringToInterval(input: UTF8String): CalendarInterval = {
401379
import ParseState._
380+
var state = PREFIX
381+
def exceptionWithState(msg: String, e: Exception = null) = {
382+
throw new IllegalArgumentException(s"Error parsing interval in state '$state', $msg", e)
383+
}
402384

403385
if (input == null) {
404-
return null
386+
exceptionWithState("interval string cannot be null")
405387
}
406388
// scalastyle:off caselocale .toLowerCase
407389
val s = input.trim.toLowerCase
408390
// scalastyle:on
409391
val bytes = s.getBytes
410392
if (bytes.isEmpty) {
411-
return null
393+
exceptionWithState("interval string cannot be empty")
412394
}
413-
var state = PREFIX
395+
414396
var i = 0
415397
var currentValue: Long = 0
416398
var isNegative: Boolean = false
@@ -427,13 +409,17 @@ object IntervalUtils {
427409
}
428410
}
429411

412+
def nextWord: UTF8String = {
413+
s.substring(i, s.numBytes()).subStringIndex(UTF8String.blankString(1), 1)
414+
}
415+
430416
while (i < bytes.length) {
431417
val b = bytes(i)
432418
state match {
433419
case PREFIX =>
434420
if (s.startsWith(intervalStr)) {
435421
if (s.numBytes() == intervalStr.numBytes()) {
436-
return null
422+
exceptionWithState("interval string cannot be empty")
437423
} else {
438424
i += intervalStr.numBytes()
439425
}
@@ -450,7 +436,7 @@ object IntervalUtils {
450436
i += 1
451437
case _ if '0' <= b && b <= '9' =>
452438
isNegative = false
453-
case _ => return null
439+
case _ => exceptionWithState( s"Unrecognized sign '$nextWord'")
454440
}
455441
currentValue = 0
456442
fraction = 0
@@ -465,13 +451,14 @@ object IntervalUtils {
465451
try {
466452
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
467453
} catch {
468-
case _: ArithmeticException => return null
454+
case _: ArithmeticException =>
455+
exceptionWithState(s"'$currentValue$nextWord' out of range")
469456
}
470457
case ' ' => state = TRIM_BEFORE_UNIT
471458
case '.' =>
472459
fractionScale = (NANOS_PER_SECOND / 10).toInt
473460
state = VALUE_FRACTIONAL_PART
474-
case _ => return null
461+
case _ => exceptionWithState(s"invalid value '$nextWord'")
475462
}
476463
i += 1
477464
case VALUE_FRACTIONAL_PART =>
@@ -482,14 +469,14 @@ object IntervalUtils {
482469
case ' ' =>
483470
fraction /= NANOS_PER_MICROS.toInt
484471
state = TRIM_BEFORE_UNIT
485-
case _ => return null
472+
case _ => exceptionWithState(s"invalid value fractional part '$fraction$nextWord'")
486473
}
487474
i += 1
488475
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
489476
case UNIT_BEGIN =>
490477
// Checks that only seconds can have the fractional part
491478
if (b != 's' && fractionScale >= 0) {
492-
return null
479+
exceptionWithState(s"'$nextWord' with fractional part is unsupported")
493480
}
494481
if (isNegative) {
495482
currentValue = -currentValue
@@ -533,26 +520,26 @@ object IntervalUtils {
533520
} else if (s.matchAt(microsStr, i)) {
534521
microseconds = Math.addExact(microseconds, currentValue)
535522
i += microsStr.numBytes()
536-
} else return null
537-
case _ => return null
523+
} else exceptionWithState(s"invalid unit '$nextWord'")
524+
case _ => exceptionWithState(s"invalid unit '$nextWord'")
538525
}
539526
} catch {
540-
case _: ArithmeticException => return null
527+
case e: ArithmeticException => exceptionWithState(e.getMessage, e)
541528
}
542529
state = UNIT_SUFFIX
543530
case UNIT_SUFFIX =>
544531
b match {
545532
case 's' => state = UNIT_END
546533
case ' ' => state = TRIM_BEFORE_SIGN
547-
case _ => return null
534+
case _ => exceptionWithState(s"invalid unit suffix '$nextWord'")
548535
}
549536
i += 1
550537
case UNIT_END =>
551538
b match {
552539
case ' ' =>
553540
i += 1
554541
state = TRIM_BEFORE_SIGN
555-
case _ => return null
542+
case _ => exceptionWithState(s"invalid unit suffix '$nextWord'")
556543
}
557544
}
558545
}
@@ -562,7 +549,6 @@ object IntervalUtils {
562549
new CalendarInterval(months, days, microseconds)
563550
case _ => null
564551
}
565-
566552
result
567553
}
568554

0 commit comments

Comments
 (0)