Commit 17601e0

MaxGekk authored and cloud-fan committed
[SPARK-34605][SQL] Support java.time.Duration as an external type of the day-time interval type
### What changes were proposed in this pull request?

In the PR, I propose to extend the Spark SQL API to accept [`java.time.Duration`](https://docs.oracle.com/javase/8/docs/api/java/time/Duration.html) as an external type of the recently added Catalyst type `DayTimeIntervalType` (see #31614). The Java class `java.time.Duration` has similar semantics to the ANSI SQL day-time interval type, which makes it the most suitable external type for `DayTimeIntervalType`. In more detail:

1. Added `DurationConverter`, which converts `java.time.Duration` instances to/from the internal representation of the Catalyst type `DayTimeIntervalType` (a `Long`). The `DurationConverter` object uses new methods of `IntervalUtils`:
   - `durationToMicros()` converts the input duration to the total length in microseconds. If the duration is too large to fit in a `Long`, the method throws an `ArithmeticException`. **Note:** _the input duration has nanosecond precision; the method casts the nanos part to microseconds by dividing by 1000._
   - `microsToDuration()` obtains a `java.time.Duration` representing a number of microseconds.
2. Supported the new type `DayTimeIntervalType` in `RowEncoder` via the methods `createDeserializerForDuration()` and `createSerializerForJavaDuration()`.
3. Extended the Literal API to construct literals from `java.time.Duration` instances.

### Why are the changes needed?

1. To allow users to parallelize collections of `java.time.Duration`, construct day-time interval columns, and collect such columns back to the driver side.
2. This allows writing tests in other sub-tasks of SPARK-27790.

### Does this PR introduce _any_ user-facing change?

The PR extends existing functionality, so users can parallelize instances of the `java.time.Duration` class and collect them back:

```Scala
scala> val ds = Seq(java.time.Duration.ofDays(10)).toDS
ds: org.apache.spark.sql.Dataset[java.time.Duration] = [value: daytimeinterval]

scala> ds.collect
res0: Array[java.time.Duration] = Array(PT240H)
```

### How was this patch tested?

- Added a few tests to `CatalystTypeConvertersSuite` to check conversion from/to `java.time.Duration`.
- Checked row encoding with new tests in `RowEncoderSuite`.
- Literals of `DayTimeIntervalType` are tested in `LiteralExpressionSuite`.
- Checked collecting via `DatasetSuite` and `JavaDatasetSuite`.

Closes #31729 from MaxGekk/java-time-duration.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
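The nanosecond-truncation note in point 1 can be illustrated with a small standalone sketch. This is plain JDK arithmetic, not the commit's code; the two constants are local assumptions mirroring those in `IntervalUtils`:

```Scala
import java.time.Duration

object DurationConversionSketch {
  val MICROS_PER_SECOND = 1000000L // assumed constant, mirrors IntervalUtils
  val NANOS_PER_MICROS = 1000L     // assumed constant, mirrors IntervalUtils

  // Total length in microseconds; throws ArithmeticException on Long overflow.
  def durationToMicros(d: Duration): Long = {
    val us = Math.multiplyExact(d.getSeconds, MICROS_PER_SECOND)
    // The nanos part is truncated to micros by integer division by 1000.
    Math.addExact(us, d.getNano / NANOS_PER_MICROS)
  }

  def main(args: Array[String]): Unit = {
    println(durationToMicros(Duration.ofDays(10)))    // 864000000000
    println(durationToMicros(Duration.ofNanos(1999))) // 1: sub-microsecond part dropped
  }
}
```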
1 parent e7e0161 commit 17601e0

File tree

23 files changed: +229 −20 lines

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java

Lines changed: 3 additions & 0 deletions

```diff
@@ -83,6 +83,9 @@ public static Object read(
     if (handleUserDefinedType && dataType instanceof UserDefinedType) {
       return obj.get(ordinal, ((UserDefinedType)dataType).sqlType());
     }
+    if (dataType instanceof DayTimeIntervalType) {
+      return obj.getLong(ordinal);
+    }

     throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
   }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala

Lines changed: 8 additions & 0 deletions

```diff
@@ -135,6 +135,14 @@ object Encoders {
    */
  def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()

+  /**
+   * Creates an encoder that serializes instances of the `java.time.Duration` class
+   * to the internal representation of nullable Catalyst's DayTimeIntervalType.
+   *
+   * @since 3.2.0
+   */
+  def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
+
   /**
    * Creates an encoder for Java Bean of type T.
    *
```
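The new `Encoders.DURATION` lets callers build a `Dataset[java.time.Duration]` with an explicit encoder instead of implicits; a minimal sketch, assuming a running `SparkSession` named `spark`:

```Scala
import java.time.Duration
import org.apache.spark.sql.Encoders

// Explicit-encoder variant of the toDS example from the commit message.
val ds = spark.createDataset(Seq(Duration.ofDays(1), Duration.ofHours(12)))(Encoders.DURATION)
ds.collect().foreach(println) // PT24H, PT12H
```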

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 15 additions & 1 deletion

```diff
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable

@@ -74,6 +74,7 @@ object CatalystTypeConverters {
       case LongType => LongConverter
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
+      case DayTimeIntervalType => DurationConverter
       case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -400,6 +401,18 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column)
   }

+  private object DurationConverter extends CatalystTypeConverter[Duration, Duration, Any] {
+    override def toCatalystImpl(scalaValue: Duration): Long = {
+      IntervalUtils.durationToMicros(scalaValue)
+    }
+    override def toScala(catalystValue: Any): Duration = {
+      if (catalystValue == null) null
+      else IntervalUtils.microsToDuration(catalystValue.asInstanceOf[Long])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): Duration =
+      IntervalUtils.microsToDuration(row.getLong(column))
+  }
+
   /**
    * Creates a converter function that will convert Scala objects to the specified Catalyst type.
    * Typical use case would be converting a collection of rows that have the same schema. You will
@@ -465,6 +478,7 @@ object CatalystTypeConverters {
         map,
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
+    case d: Duration => DurationConverter.toCatalyst(d)
     case other => other
   }
```
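For illustration, the converter added here is reachable through `CatalystTypeConverters.convertToCatalyst` (an internal API; sketch only):

```Scala
import java.time.Duration
import org.apache.spark.sql.catalyst.CatalystTypeConverters

// A Duration is converted to its internal form: a Long count of microseconds.
val micros = CatalystTypeConverters.convertToCatalyst(Duration.ofMinutes(1))
println(micros) // 60000000
```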

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala

Lines changed: 10 additions & 1 deletion

```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.types._

 object DeserializerBuildHelper {
@@ -143,6 +143,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }

+  def createDeserializerForDuration(path: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      ObjectType(classOf[java.time.Duration]),
+      "microsToDuration",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   /**
    * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
    * and lost the required data type, which may lead to runtime error if the real type doesn't
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 4 additions & 2 deletions

```diff
@@ -133,7 +133,8 @@ object InternalRow {
     case ByteType => (input, ordinal) => input.getByte(ordinal)
     case ShortType => (input, ordinal) => input.getShort(ordinal)
     case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
-    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
+    case LongType | TimestampType | DayTimeIntervalType =>
+      (input, ordinal) => input.getLong(ordinal)
     case FloatType => (input, ordinal) => input.getFloat(ordinal)
     case DoubleType => (input, ordinal) => input.getDouble(ordinal)
     case StringType => (input, ordinal) => input.getUTF8String(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
     case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
     case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
     case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
-    case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
+    case LongType | TimestampType | DayTimeIntervalType =>
+      (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
     case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
     case CalendarIntervalType =>
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 6 additions & 0 deletions

```diff
@@ -118,6 +118,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+      case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)

       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
@@ -249,6 +250,9 @@ object JavaTypeInference {
       case c if c == classOf[java.sql.Timestamp] =>
         createDeserializerForSqlTimestamp(path)

+      case c if c == classOf[java.time.Duration] =>
+        createDeserializerForDuration(path)
+
       case c if c == classOf[java.lang.String] =>
         createDeserializerForString(path, returnNullable = true)

@@ -406,6 +410,8 @@ object JavaTypeInference {

       case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

+      case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)
+
       case c if c == classOf[java.math.BigDecimal] =>
         createSerializerForJavaBigDecimal(inputObject)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 12 additions & 2 deletions

```diff
@@ -240,6 +240,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createDeserializerForSqlTimestamp(path)

+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        createDeserializerForDuration(path)
+
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)

@@ -522,6 +525,9 @@ object ScalaReflection extends ScalaReflection {

       case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)

+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        createSerializerForJavaDuration(inputObject)
+
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createSerializerForScalaBigDecimal(inputObject)

@@ -740,6 +746,8 @@ object ScalaReflection extends ScalaReflection {
     case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
     case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
       Schema(CalendarIntervalType, nullable = true)
+    case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+      Schema(DayTimeIntervalType, nullable = true)
     case t if isSubtype(t, localTypeOf[BigDecimal]) =>
       Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
     case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -837,7 +845,8 @@ object ScalaReflection extends ScalaReflection {
     DateType -> classOf[DateType.InternalType],
     TimestampType -> classOf[TimestampType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
-    CalendarIntervalType -> classOf[CalendarInterval]
+    CalendarIntervalType -> classOf[CalendarInterval],
+    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
   )

   val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -849,7 +858,8 @@ object ScalaReflection extends ScalaReflection {
     FloatType -> classOf[java.lang.Float],
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
-    TimestampType -> classOf[java.lang.Long]
+    TimestampType -> classOf[java.lang.Long],
+    DayTimeIntervalType -> classOf[java.lang.Long]
   )

   def dataTypeJavaClass(dt: DataType): Class[_] = {
```
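With `ScalaReflection` aware of `java.time.Duration`, product types containing a duration field get encoders derived automatically. A sketch, assuming a `SparkSession` named `spark`; the case class is hypothetical:

```Scala
import java.time.Duration

// Hypothetical case class for illustration.
case class Timing(label: String, elapsed: Duration)

import spark.implicits._
// ScalaReflection maps the Duration field to DayTimeIntervalType.
val ds = Seq(Timing("job-1", Duration.ofMillis(1500))).toDS()
ds.schema.foreach(f => println(s"${f.name}: ${f.dataType}"))
```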

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 10 additions & 1 deletion

```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

 import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -104,6 +104,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }

+  def createSerializerForJavaDuration(inputObject: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      DayTimeIntervalType,
+      "durationToMicros",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 5 additions & 0 deletions

```diff
@@ -297,6 +297,11 @@ package object dsl {
     /** Creates a new AttributeReference of type timestamp */
     def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()

+    /** Creates a new AttributeReference of the day-time interval type */
+    def dayTimeInterval: AttributeReference = {
+      AttributeReference(s, DayTimeIntervalType, nullable = true)()
+    }
+
     /** Creates a new AttributeReference of type binary */
     def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()
```
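A sketch of how the new test-DSL helper might be used (internal Catalyst DSL; this assumes the usual `dsl.expressions._` import works here as in other Spark test suites):

```Scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.DayTimeIntervalType

// Symbol("d").dayTimeInterval creates an AttributeReference of DayTimeIntervalType,
// mirroring how Symbol("d").timestamp or Symbol("d").binary work for their types.
val attr = Symbol("d").dayTimeInterval
assert(attr.dataType == DayTimeIntervalType)
```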

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 7 additions & 0 deletions

```diff
@@ -53,6 +53,8 @@ import org.apache.spark.sql.types._
 *   TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
 *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
 *
+ *   DayTimeIntervalType -> java.time.Duration
+ *
 *   BinaryType -> byte array
 *   ArrayType -> scala.collection.Seq or Array
 *   MapType -> scala.collection.Map
@@ -108,6 +110,8 @@ object RowEncoder {
         createSerializerForSqlDate(inputObject)
       }

+    case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
+
     case d: DecimalType =>
       CheckOverflow(StaticInvoke(
         Decimal.getClass,
@@ -226,6 +230,7 @@ object RowEncoder {
       } else {
         ObjectType(classOf[java.sql.Date])
       }
+    case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -281,6 +286,8 @@ object RowEncoder {
       createDeserializerForSqlDate(input)
     }

+    case DayTimeIntervalType => createDeserializerForDuration(input)
+
     case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

     case StringType => createDeserializerForString(input, returnNullable = false)
```
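Since `RowEncoder` now maps `DayTimeIntervalType` to `java.time.Duration`, a row-based DataFrame can carry durations as well; a sketch assuming a `SparkSession` named `spark`:

```Scala
import java.time.Duration
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DayTimeIntervalType, StructField, StructType}

val schema = StructType(Seq(StructField("d", DayTimeIntervalType, nullable = true)))
val rows = java.util.Arrays.asList(Row(Duration.ofHours(5)))
val df = spark.createDataFrame(rows, schema)
df.collect().foreach(r => println(r.getAs[Duration](0))) // PT5H
```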

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -160,7 +160,7 @@ object InterpretedUnsafeProjection {
       case IntegerType | DateType =>
         (v, i) => writer.write(i, v.getInt(i))

-      case LongType | TimestampType =>
+      case LongType | TimestampType | DayTimeIntervalType =>
         (v, i) => writer.write(i, v.getLong(i))

       case FloatType =>
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -195,8 +195,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
   private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
     // We use INT for DATE internally
     case IntegerType | DateType => new MutableInt
-    // We use Long for Timestamp internally
-    case LongType | TimestampType => new MutableLong
+    // We use Long for Timestamp and DayTimeInterval internally
+    case LongType | TimestampType | DayTimeIntervalType => new MutableLong
     case FloatType => new MutableFloat
     case DoubleType => new MutableDouble
     case BooleanType => new MutableBoolean
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -1813,7 +1813,7 @@ object CodeGenerator extends Logging {
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
     case IntegerType | DateType => JAVA_INT
-    case LongType | TimestampType => JAVA_LONG
+    case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
     case _: DecimalType => "Decimal"
@@ -1834,7 +1834,7 @@ object CodeGenerator extends Logging {
     case ByteType => java.lang.Byte.TYPE
     case ShortType => java.lang.Short.TYPE
     case IntegerType | DateType => java.lang.Integer.TYPE
-    case LongType | TimestampType => java.lang.Long.TYPE
+    case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
     case FloatType => java.lang.Float.TYPE
     case DoubleType => java.lang.Double.TYPE
     case _: DecimalType => classOf[Decimal]
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 7 additions & 3 deletions

```diff
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
+import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -76,6 +77,7 @@ object Literal {
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
+    case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
     case a: Array[Byte] => Literal(a, BinaryType)
     case a: collection.mutable.WrappedArray[_] => apply(a.array)
     case a: Array[_] =>
@@ -111,6 +113,7 @@ object Literal {
     case _ if clz == classOf[Date] => DateType
     case _ if clz == classOf[Instant] => TimestampType
     case _ if clz == classOf[Timestamp] => TimestampType
+    case _ if clz == classOf[Duration] => DayTimeIntervalType
     case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
     case _ if clz == classOf[Array[Byte]] => BinaryType
     case _ if clz == classOf[Array[Char]] => StringType
@@ -167,6 +170,7 @@ object Literal {
     case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
+    case DayTimeIntervalType => create(0L, DayTimeIntervalType)
     case StringType => Literal("")
     case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
     case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -186,7 +190,7 @@ object Literal {
     case ByteType => v.isInstanceOf[Byte]
     case ShortType => v.isInstanceOf[Short]
     case IntegerType | DateType => v.isInstanceOf[Int]
-    case LongType | TimestampType => v.isInstanceOf[Long]
+    case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
     case FloatType => v.isInstanceOf[Float]
     case DoubleType => v.isInstanceOf[Double]
     case _: DecimalType => v.isInstanceOf[Decimal]
@@ -388,7 +392,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
       }
     case ByteType | ShortType =>
       ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
-    case TimestampType | LongType =>
+    case TimestampType | LongType | DayTimeIntervalType =>
       toExprCode(s"${value}L")
     case _ =>
       val constRef = ctx.addReferenceObj("literal", value, javaType)
```
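The extended Literal API can be exercised directly (internal API; sketch only):

```Scala
import java.time.Duration
import org.apache.spark.sql.catalyst.expressions.Literal

// Literal.apply pattern-matches on Duration and stores micros with DayTimeIntervalType.
val lit = Literal(Duration.ofDays(1))
println(lit.value)    // 86400000000
println(lit.dataType) // DayTimeIntervalType
```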

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 29 additions & 0 deletions

```diff
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.util

+import java.time.Duration
+import java.time.temporal.ChronoUnit
 import java.util.concurrent.TimeUnit

 import scala.util.control.NonFatal
@@ -762,4 +764,31 @@ object IntervalUtils {

     new CalendarInterval(totalMonths, totalDays, micros)
   }
+
+  /**
+   * Converts this duration to the total length in microseconds.
+   * <p>
+   * If this duration is too large to fit in a [[Long]] microseconds, then an
+   * exception is thrown.
+   * <p>
+   * If this duration has greater than microsecond precision, then the conversion
+   * will drop any excess precision information as though the amount in nanoseconds
+   * was subject to integer division by one thousand.
+   *
+   * @return The total length of the duration in microseconds
+   * @throws ArithmeticException If numeric overflow occurs
+   */
+  def durationToMicros(duration: Duration): Long = {
+    val us = Math.multiplyExact(duration.getSeconds, MICROS_PER_SECOND)
+    val result = Math.addExact(us, duration.getNano / NANOS_PER_MICROS)
+    result
+  }
+
+  /**
+   * Obtains a [[Duration]] representing a number of microseconds.
+   *
+   * @param micros The number of microseconds, positive or negative
+   * @return A [[Duration]], not null
+   */
+  def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)
 }
```
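A quick round-trip check of the two new `IntervalUtils` methods (internal API; sketch only):

```Scala
import java.time.Duration
import org.apache.spark.sql.catalyst.util.IntervalUtils

val micros = IntervalUtils.durationToMicros(Duration.ofDays(10))
assert(micros == 864000000000L)
// The round trip restores the duration (up to microsecond precision).
assert(IntervalUtils.microsToDuration(micros) == Duration.ofDays(10))

// Overflow surfaces as ArithmeticException rather than silent wrap-around,
// e.g. IntervalUtils.durationToMicros(Duration.ofSeconds(Long.MaxValue)) throws.
```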
