Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,24 @@ def trunc(date, format):
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))


@since(2.3)
def date_trunc(format, timestamp):
"""
Returns timestamp truncated to the unit specified by the format.

:param format: 'year', 'YYYY', 'yy', 'month', 'mon', 'mm',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: YYYY -> yyyy

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also update the original trunc

'DAY', 'DD', 'HOUR', 'MINUTE', 'SECOND', 'WEEK', 'QUARTER'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make those lowercased too?


>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a timestamp string like 1997-02-28 05:02:11 to show the difference from trunc a bit more clearly?

>>> df.select(date_trunc('year', df.d).alias('year')).collect()
[Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
>>> df.select(date_trunc('mon', df.d).alias('month')).collect()
[Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp)))


@since(1.5)
def next_day(date, dayOfWeek):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ object FunctionRegistry {
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[TruncDate]("trunc"),
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[DayOfWeek]("dayofweek"),
expression[WeekOfYear]("weekofyear"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1295,87 +1295,184 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
override def dataType: DataType = TimestampType
}

/**
* Returns date truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.",
examples = """
Examples:
> SELECT _FUNC_('2009-02-12', 'MM');
2009-02-01
> SELECT _FUNC_('2015-10-27', 'YEAR');
2015-01-01
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
trait TruncTime extends BinaryExpression with ImplicitCastInputTypes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe TruncInstant? I received this advice before and I liked it too. Not a big deal tho.

val time: Expression
val format: Expression
override def nullable: Boolean = true
override def prettyName: String = "trunc"

private lazy val truncLevel: Int =
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])

override def eval(input: InternalRow): Any = {
/**
*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this line.

* @param input
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems input and truncFunc descriptions missing.

* @param maxLevel Maximum level that can be used for truncation (e.g MONTH for Date input)
* @param truncFunc
* @tparam T
* @return
*/
protected def evalHelper[T](input: InternalRow, maxLevel: Int)(
truncFunc: (Any, Int) => T): Any = {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe truncFunc: (Any, Int) => Any is enough? So we don't need to use the T, but I'm not sure if this is better...

val level = if (format.foldable) {
truncLevel
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
if (level == -1) {
if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// unknown format or too small level?

// unknown format
null
} else {
val d = date.eval(input)
val d = time.eval(input)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Since this is a time, it can be val t = ...

if (d == null) {
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
truncFunc(d, level)
}
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
protected def codeGenHelper[T](
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a type parameter T?

ctx: CodegenContext,
ev: ExprCode,
maxLevel: Int,
orderReversed: Boolean = false)(
truncFunc: (String, String) => String)
: ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (truncLevel == -1) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
} else {
val d = date.genCode(ctx)
val d = time.genCode(ctx)
val truncFuncStr = truncFunc(d.value, truncLevel.toString)
ev.copy(code = s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
${ev.value} = $dtu.$truncFuncStr;
}""")
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
nullSafeCodeGen(ctx, ev, (left, right) => {
val form = ctx.freshName("form")
val (dateVal, fmt) = if (orderReversed) {
(right, left)
} else {
(left, right)
}
val truncFuncStr = truncFunc(dateVal, form)
s"""
int $form = $dtu.parseTruncLevel($fmt);
if ($form == -1) {
if ($form == -1 || $form > $maxLevel) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $form);
${ev.value} = $dtu.$truncFuncStr
}
"""
})
}
}
}

/**
* Returns date truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
`fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us use the lower case and also update the other functions in this file. For example, ToUnixTimestamp

""",
examples = """
Examples:
> SELECT _FUNC_('2009-02-12', 'MM');
2009-02-01
> SELECT _FUNC_('2015-10-27', 'YEAR');
2015-01-01
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends TruncTime {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
override def prettyName: String = "trunc"
override val time = date

override def eval(input: InternalRow): Any = {
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, level: Int) =>
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) =>
s"truncDate($date, $fmt);"
}
}
}

/**
* Returns timestamp truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(fmt, date) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

date -> ts.

`fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"]
""",
examples = """
Examples:
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR');
2015-01-01T00:00:00
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM');
2015-03-01T00:00:00
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD');
2015-03-05T00:00:00
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR');
2015-03-05T09:00:00
""",
since = "2.3.0")
// scalastyle:on line.size.limit
case class TruncTimestamp(
format: Expression,
timestamp: Expression,
timeZoneId: Option[String] = None)
extends TruncTime with TimeZoneAwareExpression {
override def left: Expression = format
override def right: Expression = timestamp

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
override def dataType: TimestampType = TimestampType
override def prettyName: String = "date_trunc"
override val time = timestamp
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

def this(format: Expression, timestamp: Expression) = this(format, timestamp, None)

override def eval(input: InternalRow): Any = {
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_QUARTER) { (d: Any, level: Int) =>
DateTimeUtils.truncTimestamp(d.asInstanceOf[Long], level, timeZone)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceObj("timeZone", timeZone)
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_QUARTER, true) {
(date: String, fmt: String) =>
s"truncTimestamp($date, $fmt, $tz);"
}
}
}

/**
* Returns the number of days from startDate to endDate.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ object DateTimeUtils {
// it's 2440587.5, rounding up to compatible with Hive
final val JULIAN_DAY_OF_EPOCH = 2440588
final val SECONDS_PER_DAY = 60 * 60 * 24L
final val MICROS_PER_SECOND = 1000L * 1000L
final val MICROS_PER_MILLIS = 1000L
final val MICROS_PER_SECOND = MICROS_PER_MILLIS * MILLIS_PER_SECOND
final val MILLIS_PER_SECOND = 1000L
final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L
final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY
Expand Down Expand Up @@ -909,20 +910,29 @@ object DateTimeUtils {
math.round(diff * 1e8) / 1e8
}

// Thursday = 0 since 1970/Jan/01 => Thursday
private val SUNDAY = 3
private val MONDAY = 4
private val TUESDAY = 5
private val WEDNESDAY = 6
private val THURSDAY = 0
private val FRIDAY = 1
private val SATURDAY = 2

/*
* Returns day of week from String. Starting from Thursday, marked as 0.
* (Because 1970-01-01 is Thursday).
*/
def getDayOfWeekFromString(string: UTF8String): Int = {
val dowString = string.toString.toUpperCase(Locale.ROOT)
dowString match {
case "SU" | "SUN" | "SUNDAY" => 3
case "MO" | "MON" | "MONDAY" => 4
case "TU" | "TUE" | "TUESDAY" => 5
case "WE" | "WED" | "WEDNESDAY" => 6
case "TH" | "THU" | "THURSDAY" => 0
case "FR" | "FRI" | "FRIDAY" => 1
case "SA" | "SAT" | "SATURDAY" => 2
case "SU" | "SUN" | "SUNDAY" => SUNDAY
case "MO" | "MON" | "MONDAY" => MONDAY
case "TU" | "TUE" | "TUESDAY" => TUESDAY
case "WE" | "WED" | "WEDNESDAY" => WEDNESDAY
case "TH" | "THU" | "THURSDAY" => THURSDAY
case "FR" | "FRI" | "FRIDAY" => FRIDAY
case "SA" | "SAT" | "SATURDAY" => SATURDAY
case _ => -1
}
}
Expand All @@ -944,9 +954,16 @@ object DateTimeUtils {
date + daysToMonthEnd
}

private val TRUNC_TO_YEAR = 1
private val TRUNC_TO_MONTH = 2
private val TRUNC_INVALID = -1
// Visible for testing.
val TRUNC_TO_YEAR = 1
val TRUNC_TO_MONTH = 2
val TRUNC_TO_DAY = 3
val TRUNC_TO_HOUR = 4
val TRUNC_TO_MINUTE = 5
val TRUNC_TO_SECOND = 6
val TRUNC_TO_WEEK = 7
val TRUNC_TO_QUARTER = 8
val TRUNC_INVALID = -1
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we bring quarter and week forward, maybe to 3 and 4? Then it's more conform to the order of time granularity and max-level design is not influenced.


/**
* Returns the trunc date from original date and trunc level.
Expand All @@ -964,7 +981,62 @@ object DateTimeUtils {
}

/**
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID,
* Returns the trunc date time from original date time and trunc level.
* Trunc level should be generated using `parseTruncLevel()`, should be between 1 and 8
*/
def truncTimestamp(d: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: d -> ts or t

var millis = d / MICROS_PER_MILLIS
val truncated = level match {
case TRUNC_TO_YEAR =>
val dDays = millisToDays(millis, timeZone)
daysToMillis(truncDate(dDays, level), timeZone)
case TRUNC_TO_MONTH =>
val dDays = millisToDays(millis, timeZone)
daysToMillis(truncDate(dDays, level), timeZone)
case TRUNC_TO_DAY =>
val offset = timeZone.getOffset(millis)
millis += offset
millis - millis % (MILLIS_PER_SECOND * SECONDS_PER_DAY) - offset
case TRUNC_TO_HOUR =>
val offset = timeZone.getOffset(millis)
millis += offset
millis - millis % (60 * 60 * MILLIS_PER_SECOND) - offset
case TRUNC_TO_MINUTE =>
millis - millis % (60 * MILLIS_PER_SECOND)
case TRUNC_TO_SECOND =>
millis - millis % MILLIS_PER_SECOND
case TRUNC_TO_WEEK =>
val dDays = millisToDays(millis, timeZone)
val prevMonday = getNextDateForDayOfWeek(dDays - 7, MONDAY)
daysToMillis(prevMonday, timeZone)
case TRUNC_TO_QUARTER =>
val dDays = millisToDays(millis, timeZone)
millis = daysToMillis(truncDate(dDays, TRUNC_TO_MONTH), timeZone)
val cal = Calendar.getInstance()
cal.setTimeInMillis(millis)
val quarter = getQuarter(dDays)
val month = quarter match {
case 1 => Calendar.JANUARY
case 2 => Calendar.APRIL
case 3 => Calendar.JULY
case 4 => Calendar.OCTOBER
}
cal.set(Calendar.MONTH, month)
cal.getTimeInMillis()
case _ =>
// caller make sure that this should never be reached
sys.error(s"Invalid trunc level: $level")
}
truncated * MICROS_PER_MILLIS
}

def truncTimestamp(d: SQLTimestamp, level: Int): SQLTimestamp = {
truncTimestamp(d, level, defaultTimeZone())
}

/**
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, TRUNC_TO_DAY, TRUNC_TO_HOUR,
* TRUNC_TO_MINUTE, TRUNC_TO_SECOND, TRUNC_TO_WEEK, TRUNC_TO_QUARTER or TRUNC_INVALID,
* TRUNC_INVALID means unsupported truncate level.
*/
def parseTruncLevel(format: UTF8String): Int = {
Expand All @@ -974,6 +1046,12 @@ object DateTimeUtils {
format.toString.toUpperCase(Locale.ROOT) match {
case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
case "DAY" | "DD" => TRUNC_TO_DAY
case "HOUR" => TRUNC_TO_HOUR
case "MINUTE" => TRUNC_TO_MINUTE
case "SECOND" => TRUNC_TO_SECOND
case "WEEK" => TRUNC_TO_WEEK
case "QUARTER" => TRUNC_TO_QUARTER
case _ => TRUNC_INVALID
}
}
Expand Down
Loading