Skip to content

[SPARK-34615][SQL] Support java.time.Period as an external type of the year-month interval type #31765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public static Object read(
if (dataType instanceof DayTimeIntervalType) {
return obj.getLong(ordinal);
}
if (dataType instanceof YearMonthIntervalType) {
return obj.getInt(ordinal);
}

throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ object Encoders {
*/
def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()

/**
* Creates an encoder that serializes instances of the `java.time.Period` class
* to the internal representation of nullable Catalyst's YearMonthIntervalType.
*
* @since 3.2.0
*/
def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()

/**
* Creates an encoder for Java Bean of type T.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate, Period}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

Expand Down Expand Up @@ -75,6 +75,7 @@ object CatalystTypeConverters {
case FloatType => FloatConverter
case DoubleType => DoubleConverter
case DayTimeIntervalType => DurationConverter
case YearMonthIntervalType => PeriodConverter
case dataType: DataType => IdentityConverter(dataType)
}
converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
Expand Down Expand Up @@ -413,6 +414,18 @@ object CatalystTypeConverters {
IntervalUtils.microsToDuration(row.getLong(column))
}

private object PeriodConverter extends CatalystTypeConverter[Period, Period, Any] {
override def toCatalystImpl(scalaValue: Period): Int = {
IntervalUtils.periodToMonths(scalaValue)
}
override def toScala(catalystValue: Any): Period = {
if (catalystValue == null) null
else IntervalUtils.monthsToPeriod(catalystValue.asInstanceOf[Int])
}
override def toScalaImpl(row: InternalRow, column: Int): Period =
IntervalUtils.monthsToPeriod(row.getInt(column))
}

/**
* Creates a converter function that will convert Scala objects to the specified Catalyst type.
* Typical use case would be converting a collection of rows that have the same schema. You will
Expand Down Expand Up @@ -479,6 +492,7 @@ object CatalystTypeConverters {
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case d: Duration => DurationConverter.toCatalyst(d)
case p: Period => PeriodConverter.toCatalyst(p)
case other => other
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

def createDeserializerForPeriod(path: Expression): Expression = {
StaticInvoke(
IntervalUtils.getClass,
ObjectType(classOf[java.time.Period]),
"monthsToPeriod",
path :: Nil,
returnNullable = false)
}

/**
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ object InternalRow {
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
case IntegerType | DateType | YearMonthIntervalType =>
(input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType | DayTimeIntervalType =>
(input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
Expand Down Expand Up @@ -168,7 +169,8 @@ object InternalRow {
case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean])
case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case IntegerType | DateType | YearMonthIntervalType =>
(input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case LongType | TimestampType | DayTimeIntervalType =>
(input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
Expand Down Expand Up @@ -253,6 +254,9 @@ object JavaTypeInference {
case c if c == classOf[java.time.Duration] =>
createDeserializerForDuration(path)

case c if c == classOf[java.time.Period] =>
createDeserializerForPeriod(path)

case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)

Expand Down Expand Up @@ -412,6 +416,8 @@ object JavaTypeInference {

case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)

case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createDeserializerForDuration(path)

case t if isSubtype(t, localTypeOf[java.time.Period]) =>
createDeserializerForPeriod(path)

case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)

Expand Down Expand Up @@ -528,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createSerializerForJavaDuration(inputObject)

case t if isSubtype(t, localTypeOf[java.time.Period]) =>
createSerializerForJavaPeriod(inputObject)

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)

Expand Down Expand Up @@ -748,6 +754,8 @@ object ScalaReflection extends ScalaReflection {
Schema(CalendarIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
Schema(DayTimeIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Period]) =>
Schema(YearMonthIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
Expand Down Expand Up @@ -846,7 +854,8 @@ object ScalaReflection extends ScalaReflection {
TimestampType -> classOf[TimestampType.InternalType],
BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval],
DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType]
)

val typeBoxedJavaMapping = Map[DataType, Class[_]](
Expand All @@ -859,7 +868,8 @@ object ScalaReflection extends ScalaReflection {
DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long],
DayTimeIntervalType -> classOf[java.lang.Long]
DayTimeIntervalType -> classOf[java.lang.Long],
YearMonthIntervalType -> classOf[java.lang.Integer]
)

def dataTypeJavaClass(dt: DataType): Class[_] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ object SerializerBuildHelper {
returnNullable = false)
}

def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
StaticInvoke(
IntervalUtils.getClass,
YearMonthIntervalType,
"periodToMonths",
inputObject :: Nil,
returnNullable = false)
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ package object dsl {
AttributeReference(s, DayTimeIntervalType, nullable = true)()
}

/** Creates a new AttributeReference of the year-month interval type */
def yearMonthInterval: AttributeReference = {
AttributeReference(s, YearMonthIntervalType, nullable = true)()
}

/** Creates a new AttributeReference of type binary */
def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import org.apache.spark.sql.types._
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* DayTimeIntervalType -> java.time.Duration
* YearMonthIntervalType -> java.time.Period
*
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
Expand Down Expand Up @@ -112,6 +113,8 @@ object RowEncoder {

case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)

case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)

case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down Expand Up @@ -231,6 +234,7 @@ object RowEncoder {
ObjectType(classOf[java.sql.Date])
}
case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
case YearMonthIntervalType => ObjectType(classOf[java.time.Period])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
Expand Down Expand Up @@ -288,6 +292,8 @@ object RowEncoder {

case DayTimeIntervalType => createDeserializerForDuration(input)

case YearMonthIntervalType => createDeserializerForPeriod(input)

case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

case StringType => createDeserializerForString(input, returnNullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ object InterpretedUnsafeProjection {
case ShortType =>
(v, i) => writer.write(i, v.getShort(i))

case IntegerType | DateType =>
case IntegerType | DateType | YearMonthIntervalType =>
(v, i) => writer.write(i, v.getInt(i))

case LongType | TimestampType | DayTimeIntervalType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ final class MutableAny extends MutableValue {
final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow {

private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
// We use INT for DATE internally
case IntegerType | DateType => new MutableInt
// We use INT for DATE and YearMonthIntervalType internally
case IntegerType | DateType | YearMonthIntervalType => new MutableInt
// We use Long for Timestamp and DayTimeInterval internally
case LongType | TimestampType | DayTimeIntervalType => new MutableLong
case FloatType => new MutableFloat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1812,7 +1812,7 @@ object CodeGenerator extends Logging {
case BooleanType => JAVA_BOOLEAN
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
Expand All @@ -1833,7 +1833,7 @@ object CodeGenerator extends Logging {
case BooleanType => java.lang.Boolean.TYPE
case ByteType => java.lang.Byte.TYPE
case ShortType => java.lang.Short.TYPE
case IntegerType | DateType => java.lang.Integer.TYPE
case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE
case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
case FloatType => java.lang.Float.TYPE
case DoubleType => java.lang.Double.TYPE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate, Period}
import java.util
import java.util.Objects
import javax.xml.bind.DatatypeConverter
Expand All @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -78,6 +78,7 @@ object Literal {
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
case p: Period => Literal(periodToMonths(p), YearMonthIntervalType)
case a: Array[Byte] => Literal(a, BinaryType)
case a: collection.mutable.WrappedArray[_] => apply(a.array)
case a: Array[_] =>
Expand Down Expand Up @@ -114,6 +115,7 @@ object Literal {
case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[Duration] => DayTimeIntervalType
case _ if clz == classOf[Period] => YearMonthIntervalType
case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[Array[Byte]] => BinaryType
case _ if clz == classOf[Array[Char]] => StringType
Expand Down Expand Up @@ -171,6 +173,7 @@ object Literal {
case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType)
case DayTimeIntervalType => create(0L, DayTimeIntervalType)
case YearMonthIntervalType => create(0, YearMonthIntervalType)
case StringType => Literal("")
case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
Expand All @@ -189,7 +192,7 @@ object Literal {
case BooleanType => v.isInstanceOf[Boolean]
case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short]
case IntegerType | DateType => v.isInstanceOf[Int]
case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int]
case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double]
Expand Down Expand Up @@ -366,7 +369,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
}
dataType match {
case BooleanType | IntegerType | DateType =>
case BooleanType | IntegerType | DateType | YearMonthIntervalType =>
toExprCode(value.toString)
case FloatType =>
value.asInstanceOf[Float] match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.util

import java.time.Duration
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit

Expand Down Expand Up @@ -791,4 +791,35 @@ object IntervalUtils {
* @return A [[Duration]], not null
*/
def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)

/**
* Gets the total number of months in this period.
* <p>
* This returns the total number of months in the period by multiplying the
* number of years by 12 and adding the number of months.
* <p>
*
* @return The total number of months in the period, may be negative
* @throws ArithmeticException If numeric overflow occurs
*/
def periodToMonths(period: Period): Int = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we fail if the day field is not 0? Or at least give a warning?

Copy link
Member Author

@MaxGekk MaxGekk Mar 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't fail when we convert:

  1. java.sql.Date has time component with millisecond precision but we ignore it when we convert to days at
    val julianDays = Math.toIntExact(Math.floorDiv(millisLocal, MILLIS_PER_DAY))
  2. java.sql.Timestamp which has nanoseconds precision:
    val micros = millisToMicros(t.getTime) + (t.getNanos / NANOS_PER_MICROS) % MICROS_PER_MILLIS
  3. java.time.Instant which contains nanoseconds, and we don't fail when we convert it to microseconds:
    val result = Math.addExact(us, NANOSECONDS.toMicros(instant.getNano))

To be consistent with current implementation for other types, I do believe we should not fail.

Or at least give a warning?

This will just fill in the logs by useless records, and this is again inconsistent with current implementation.

val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR)
Math.addExact(monthsInYears, period.getMonths)
}

/**
* Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years
* and months units will be normalized.
*
* <p>
* The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted
* to compensate. For example, the method returns "2 years and 3 months" for the 27 input months.
* <p>
* The sign of the years and months units will be the same after normalization.
* For example, -13 months will be converted to "-1 year and -1 month".
*
* @param months The number of months, positive or negative
* @return The period of months, not null
*/
def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized()
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ object DataType {
private val otherTypes = {
Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
DayTimeIntervalType)
DayTimeIntervalType, YearMonthIntervalType)
.map(t => t.typeName -> t).toMap
}

Expand Down
Loading