-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-22829] Add new built-in function date_trunc() #20015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1111,6 +1111,24 @@ def trunc(date, format): | |
return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) | ||
|
||
|
||
@since(2.3) | ||
def date_trunc(format, timestamp): | ||
""" | ||
Returns timestamp truncated to the unit specified by the format. | ||
|
||
:param format: 'year', 'YYYY', 'yy', 'month', 'mon', 'mm', | ||
'DAY', 'DD', 'HOUR', 'MINUTE', 'SECOND', 'WEEK', 'QUARTER' | ||
|
||
|
||
>>> df = spark.createDataFrame([('1997-02-28',)], ['d']) | ||
|
||
>>> df.select(date_trunc('year', df.d).alias('year')).collect() | ||
[Row(year=datetime.datetime(1997, 1, 1, 0, 0))] | ||
>>> df.select(date_trunc('mon', df.d).alias('month')).collect() | ||
[Row(month=datetime.datetime(1997, 2, 1, 0, 0))] | ||
""" | ||
sc = SparkContext._active_spark_context | ||
return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp))) | ||
|
||
|
||
@since(1.5) | ||
def next_day(date, dayOfWeek): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1295,87 +1295,184 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child: | |
override def dataType: DataType = TimestampType | ||
} | ||
|
||
/** | ||
* Returns date truncated to the unit specified by the format. | ||
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_('2009-02-12', 'MM'); | ||
2009-02-01 | ||
> SELECT _FUNC_('2015-10-27', 'YEAR'); | ||
2015-01-01 | ||
""", | ||
since = "1.5.0") | ||
// scalastyle:on line.size.limit | ||
case class TruncDate(date: Expression, format: Expression) | ||
extends BinaryExpression with ImplicitCastInputTypes { | ||
override def left: Expression = date | ||
override def right: Expression = format | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) | ||
override def dataType: DataType = DateType | ||
trait TruncTime extends BinaryExpression with ImplicitCastInputTypes { | ||
|
||
val time: Expression | ||
val format: Expression | ||
override def nullable: Boolean = true | ||
override def prettyName: String = "trunc" | ||
|
||
private lazy val truncLevel: Int = | ||
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) | ||
|
||
override def eval(input: InternalRow): Any = { | ||
/** | ||
* | ||
|
||
* @param input | ||
|
||
* @param maxLevel Maximum level that can be used for truncation (e.g MONTH for Date input) | ||
* @param truncFunc | ||
* @tparam T | ||
* @return | ||
*/ | ||
protected def evalHelper[T](input: InternalRow, maxLevel: Int)( | ||
truncFunc: (Any, Int) => T): Any = { | ||
|
||
val level = if (format.foldable) { | ||
truncLevel | ||
} else { | ||
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) | ||
} | ||
if (level == -1) { | ||
if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
// unknown format | ||
null | ||
} else { | ||
val d = date.eval(input) | ||
val d = time.eval(input) | ||
|
||
if (d == null) { | ||
null | ||
} else { | ||
DateTimeUtils.truncDate(d.asInstanceOf[Int], level) | ||
truncFunc(d, level) | ||
} | ||
} | ||
} | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
protected def codeGenHelper[T]( | ||
|
||
ctx: CodegenContext, | ||
ev: ExprCode, | ||
maxLevel: Int, | ||
orderReversed: Boolean = false)( | ||
truncFunc: (String, String) => String) | ||
: ExprCode = { | ||
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") | ||
|
||
if (format.foldable) { | ||
if (truncLevel == -1) { | ||
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { | ||
ev.copy(code = s""" | ||
boolean ${ev.isNull} = true; | ||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") | ||
} else { | ||
val d = date.genCode(ctx) | ||
val d = time.genCode(ctx) | ||
val truncFuncStr = truncFunc(d.value, truncLevel.toString) | ||
ev.copy(code = s""" | ||
${d.code} | ||
boolean ${ev.isNull} = ${d.isNull}; | ||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; | ||
if (!${ev.isNull}) { | ||
${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); | ||
${ev.value} = $dtu.$truncFuncStr; | ||
}""") | ||
} | ||
} else { | ||
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { | ||
nullSafeCodeGen(ctx, ev, (left, right) => { | ||
val form = ctx.freshName("form") | ||
val (dateVal, fmt) = if (orderReversed) { | ||
(right, left) | ||
} else { | ||
(left, right) | ||
} | ||
val truncFuncStr = truncFunc(dateVal, form) | ||
s""" | ||
int $form = $dtu.parseTruncLevel($fmt); | ||
if ($form == -1) { | ||
if ($form == -1 || $form > $maxLevel) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = $dtu.truncDate($dateVal, $form); | ||
${ev.value} = $dtu.$truncFuncStr | ||
} | ||
""" | ||
}) | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Returns date truncated to the unit specified by the format. | ||
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`. | ||
`fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM"] | ||
|
||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_('2009-02-12', 'MM'); | ||
2009-02-01 | ||
> SELECT _FUNC_('2015-10-27', 'YEAR'); | ||
2015-01-01 | ||
""", | ||
since = "1.5.0") | ||
// scalastyle:on line.size.limit | ||
case class TruncDate(date: Expression, format: Expression) | ||
extends TruncTime { | ||
override def left: Expression = date | ||
override def right: Expression = format | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) | ||
override def dataType: DataType = DateType | ||
override def prettyName: String = "trunc" | ||
override val time = date | ||
|
||
override def eval(input: InternalRow): Any = { | ||
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, level: Int) => | ||
DateTimeUtils.truncDate(d.asInstanceOf[Int], level) | ||
} | ||
} | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) => | ||
s"truncDate($date, $fmt);" | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Returns timestamp truncated to the unit specified by the format. | ||
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(fmt, date) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`. | ||
|
||
`fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"] | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); | ||
2015-01-01T00:00:00 | ||
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); | ||
2015-03-01T00:00:00 | ||
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); | ||
2015-03-05T00:00:00 | ||
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); | ||
2015-03-05T09:00:00 | ||
""", | ||
since = "2.3.0") | ||
// scalastyle:on line.size.limit | ||
case class TruncTimestamp( | ||
format: Expression, | ||
timestamp: Expression, | ||
timeZoneId: Option[String] = None) | ||
extends TruncTime with TimeZoneAwareExpression { | ||
override def left: Expression = format | ||
override def right: Expression = timestamp | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) | ||
override def dataType: TimestampType = TimestampType | ||
override def prettyName: String = "date_trunc" | ||
override val time = timestamp | ||
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = | ||
copy(timeZoneId = Option(timeZoneId)) | ||
|
||
def this(format: Expression, timestamp: Expression) = this(format, timestamp, None) | ||
|
||
override def eval(input: InternalRow): Any = { | ||
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_QUARTER) { (d: Any, level: Int) => | ||
DateTimeUtils.truncTimestamp(d.asInstanceOf[Long], level, timeZone) | ||
} | ||
} | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val tz = ctx.addReferenceObj("timeZone", timeZone) | ||
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_QUARTER, true) { | ||
(date: String, fmt: String) => | ||
s"truncTimestamp($date, $fmt, $tz);" | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Returns the number of days from startDate to endDate. | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,8 @@ object DateTimeUtils { | |
// it's 2440587.5, rounding up to compatible with Hive | ||
final val JULIAN_DAY_OF_EPOCH = 2440588 | ||
final val SECONDS_PER_DAY = 60 * 60 * 24L | ||
final val MICROS_PER_SECOND = 1000L * 1000L | ||
final val MICROS_PER_MILLIS = 1000L | ||
final val MICROS_PER_SECOND = MICROS_PER_MILLIS * MILLIS_PER_SECOND | ||
final val MILLIS_PER_SECOND = 1000L | ||
final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L | ||
final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY | ||
|
@@ -909,20 +910,29 @@ object DateTimeUtils { | |
math.round(diff * 1e8) / 1e8 | ||
} | ||
|
||
// Thursday = 0 since 1970/Jan/01 => Thursday | ||
private val SUNDAY = 3 | ||
private val MONDAY = 4 | ||
private val TUESDAY = 5 | ||
private val WEDNESDAY = 6 | ||
private val THURSDAY = 0 | ||
private val FRIDAY = 1 | ||
private val SATURDAY = 2 | ||
|
||
/* | ||
* Returns day of week from String. Starting from Thursday, marked as 0. | ||
* (Because 1970-01-01 is Thursday). | ||
*/ | ||
def getDayOfWeekFromString(string: UTF8String): Int = { | ||
val dowString = string.toString.toUpperCase(Locale.ROOT) | ||
dowString match { | ||
case "SU" | "SUN" | "SUNDAY" => 3 | ||
case "MO" | "MON" | "MONDAY" => 4 | ||
case "TU" | "TUE" | "TUESDAY" => 5 | ||
case "WE" | "WED" | "WEDNESDAY" => 6 | ||
case "TH" | "THU" | "THURSDAY" => 0 | ||
case "FR" | "FRI" | "FRIDAY" => 1 | ||
case "SA" | "SAT" | "SATURDAY" => 2 | ||
case "SU" | "SUN" | "SUNDAY" => SUNDAY | ||
case "MO" | "MON" | "MONDAY" => MONDAY | ||
case "TU" | "TUE" | "TUESDAY" => TUESDAY | ||
case "WE" | "WED" | "WEDNESDAY" => WEDNESDAY | ||
case "TH" | "THU" | "THURSDAY" => THURSDAY | ||
case "FR" | "FRI" | "FRIDAY" => FRIDAY | ||
case "SA" | "SAT" | "SATURDAY" => SATURDAY | ||
case _ => -1 | ||
} | ||
} | ||
|
@@ -944,9 +954,16 @@ object DateTimeUtils { | |
date + daysToMonthEnd | ||
} | ||
|
||
private val TRUNC_TO_YEAR = 1 | ||
private val TRUNC_TO_MONTH = 2 | ||
private val TRUNC_INVALID = -1 | ||
// Visible for testing. | ||
val TRUNC_TO_YEAR = 1 | ||
val TRUNC_TO_MONTH = 2 | ||
val TRUNC_TO_DAY = 3 | ||
val TRUNC_TO_HOUR = 4 | ||
val TRUNC_TO_MINUTE = 5 | ||
val TRUNC_TO_SECOND = 6 | ||
val TRUNC_TO_WEEK = 7 | ||
val TRUNC_TO_QUARTER = 8 | ||
val TRUNC_INVALID = -1 | ||
|
||
|
||
/** | ||
* Returns the trunc date from original date and trunc level. | ||
|
@@ -964,7 +981,62 @@ object DateTimeUtils { | |
} | ||
|
||
/** | ||
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID, | ||
* Returns the trunc date time from original date time and trunc level. | ||
* Trunc level should be generated using `parseTruncLevel()`, should be between 1 and 8 | ||
*/ | ||
def truncTimestamp(d: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = { | ||
|
||
var millis = d / MICROS_PER_MILLIS | ||
val truncated = level match { | ||
case TRUNC_TO_YEAR => | ||
val dDays = millisToDays(millis, timeZone) | ||
daysToMillis(truncDate(dDays, level), timeZone) | ||
case TRUNC_TO_MONTH => | ||
val dDays = millisToDays(millis, timeZone) | ||
daysToMillis(truncDate(dDays, level), timeZone) | ||
case TRUNC_TO_DAY => | ||
val offset = timeZone.getOffset(millis) | ||
millis += offset | ||
millis - millis % (MILLIS_PER_SECOND * SECONDS_PER_DAY) - offset | ||
case TRUNC_TO_HOUR => | ||
val offset = timeZone.getOffset(millis) | ||
millis += offset | ||
millis - millis % (60 * 60 * MILLIS_PER_SECOND) - offset | ||
case TRUNC_TO_MINUTE => | ||
millis - millis % (60 * MILLIS_PER_SECOND) | ||
case TRUNC_TO_SECOND => | ||
millis - millis % MILLIS_PER_SECOND | ||
case TRUNC_TO_WEEK => | ||
val dDays = millisToDays(millis, timeZone) | ||
val prevMonday = getNextDateForDayOfWeek(dDays - 7, MONDAY) | ||
daysToMillis(prevMonday, timeZone) | ||
case TRUNC_TO_QUARTER => | ||
val dDays = millisToDays(millis, timeZone) | ||
millis = daysToMillis(truncDate(dDays, TRUNC_TO_MONTH), timeZone) | ||
val cal = Calendar.getInstance() | ||
cal.setTimeInMillis(millis) | ||
val quarter = getQuarter(dDays) | ||
val month = quarter match { | ||
case 1 => Calendar.JANUARY | ||
case 2 => Calendar.APRIL | ||
case 3 => Calendar.JULY | ||
case 4 => Calendar.OCTOBER | ||
} | ||
cal.set(Calendar.MONTH, month) | ||
cal.getTimeInMillis() | ||
case _ => | ||
// caller make sure that this should never be reached | ||
sys.error(s"Invalid trunc level: $level") | ||
} | ||
truncated * MICROS_PER_MILLIS | ||
} | ||
|
||
def truncTimestamp(d: SQLTimestamp, level: Int): SQLTimestamp = { | ||
truncTimestamp(d, level, defaultTimeZone()) | ||
} | ||
|
||
/** | ||
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, TRUNC_TO_DAY, TRUNC_TO_HOUR, | ||
* TRUNC_TO_MINUTE, TRUNC_TO_SECOND, TRUNC_TO_WEEK, TRUNC_TO_QUARTER or TRUNC_INVALID, | ||
* TRUNC_INVALID means unsupported truncate level. | ||
*/ | ||
def parseTruncLevel(format: UTF8String): Int = { | ||
|
@@ -974,6 +1046,12 @@ object DateTimeUtils { | |
format.toString.toUpperCase(Locale.ROOT) match { | ||
case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR | ||
case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH | ||
case "DAY" | "DD" => TRUNC_TO_DAY | ||
case "HOUR" => TRUNC_TO_HOUR | ||
case "MINUTE" => TRUNC_TO_MINUTE | ||
case "SECOND" => TRUNC_TO_SECOND | ||
case "WEEK" => TRUNC_TO_WEEK | ||
case "QUARTER" => TRUNC_TO_QUARTER | ||
case _ => TRUNC_INVALID | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
YYYY
->yyyy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also update the original
trunc