Skip to content

Commit

Permalink
fixed tests and added type check
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekbecker committed Jun 24, 2015
1 parent 5ebb235 commit 4d8049b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,27 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.Date
import java.text.SimpleDateFormat

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class DateFormatClass(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {
case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression {

override def dataType: DataType = StringType

override def expectedChildTypes: Seq[DataType] = Seq(TimestampType, StringType)
override def checkInputDataTypes(): TypeCheckResult =
(left.dataType, right.dataType) match {
case (null, _) => TypeCheckResult.TypeCheckSuccess
case (_, null) => TypeCheckResult.TypeCheckSuccess
case (_: DateType, _: StringType) => TypeCheckResult.TypeCheckSuccess
case (_: TimestampType, _: StringType) => TypeCheckResult.TypeCheckSuccess
case (_: StringType, _: StringType) => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"DateFormat accepts date types as first argument, " +
s"and string types as second, not ${left.dataType} and ${right.dataType}")
}

override def foldable: Boolean = left.foldable && right.foldable

Expand Down Expand Up @@ -71,7 +81,7 @@ case class DateFormatClass(left: Expression, right: Expression)

val calc = left.dataType match {
case TimestampType =>
s"""$utf8.fromString(sdf.format(new java.sql.Date(${eval1.primitive} / 10000)));"""
s""""$utf8.fromString(sdf.format(new java.sql.Date(${eval1.primitive} / 10000)));"""
case DateType =>
s"""$utf8.fromString(
sdf.format($dtUtils.toJavaDate(${eval1.primitive})));"""
Expand Down Expand Up @@ -115,6 +125,17 @@ case class Year(child: Expression) extends UnaryExpression with ExpectsInputType
}
}

override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"Year accepts date types as argument, " +
s" not ${child.dataType}")
}

}

case class Month(child: Expression) extends UnaryExpression with ExpectsInputTypes {
Expand All @@ -133,6 +154,17 @@ case class Month(child: Expression) extends UnaryExpression with ExpectsInputTyp
case x: UTF8String => x.toString.toInt
}
}

override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"Month accepts date types as argument, " +
s" not ${child.dataType}")
}
}

case class Day(child: Expression) extends UnaryExpression with ExpectsInputTypes {
Expand All @@ -152,14 +184,23 @@ case class Day(child: Expression) extends UnaryExpression with ExpectsInputTypes
}
}

override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"Day accepts date types as argument, " +
s" not ${child.dataType}")
}

}

case class Hour(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Hour(child: Expression) extends UnaryExpression {

override def dataType: DataType = IntegerType

override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType)

override def foldable: Boolean = child.foldable

override def nullable: Boolean = true
Expand All @@ -170,14 +211,23 @@ case class Hour(child: Expression) extends UnaryExpression with ExpectsInputType
case x: UTF8String => x.toString.toInt
}
}

override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"Hour accepts date types as argument, " +
s" not ${child.dataType}")
}
}

case class Minute(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Minute(child: Expression) extends UnaryExpression {

override def dataType: DataType = IntegerType

override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType)

override def foldable: Boolean = child.foldable

override def nullable: Boolean = true
Expand All @@ -188,14 +238,23 @@ case class Minute(child: Expression) extends UnaryExpression with ExpectsInputTy
case x: UTF8String => x.toString.toInt
}
}

override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"Minute accepts date types as argument, " +
s" not ${child.dataType}")
}
}

case class Second(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Second(child: Expression) extends UnaryExpression {

override def dataType: DataType = IntegerType

override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType)

override def foldable: Boolean = child.foldable

override def nullable: Boolean = true
Expand All @@ -206,14 +265,23 @@ case class Second(child: Expression) extends UnaryExpression with ExpectsInputTy
case x: UTF8String => x.toString.toInt
}
}

override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"Second accepts date types as argument, " +
s" not ${child.dataType}")
}
}

case class WeekOfYear(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class WeekOfYear(child: Expression) extends UnaryExpression {

override def dataType: DataType = IntegerType

override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType)

override def foldable: Boolean = child.foldable

override def nullable: Boolean = true
Expand All @@ -225,5 +293,14 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ExpectsInp
}
}


override def checkInputDataTypes(): TypeCheckResult =
child.dataType match {
case null => TypeCheckResult.TypeCheckSuccess
case _: DateType => TypeCheckResult.TypeCheckSuccess
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
case _: StringType => TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"WeekOfYear accepts date types as argument, " +
s" not ${child.dataType}")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row("2015", "2015", "2013"))

checkAnswer(
df.selectExpr("dateFormat(a, y)", "dateFormat(b, y)", "dateFormat(c, y)"),
df.selectExpr("dateFormat(a, 'y')", "dateFormat(b, 'y')", "dateFormat(c, 'y')"),
Row("2015", "2015", "2013"))
}

Expand Down Expand Up @@ -271,11 +271,11 @@ class DataFrameFunctionsSuite extends QueryTest {
val df = Seq((d, d.toString, ts)).toDF("a", "b", "c")

checkAnswer(
df.select(hour("a"), hour("b"), hour("c")),
df.select(second("a"), second("b"), second("c")),
Row(0, 0, 15))

checkAnswer(
df.selectExpr("hour(a)", "hour(b)", "hour(c)"),
df.selectExpr("second(a)", "second(b)", "second(c)"),
Row(0, 0, 15))
}

Expand Down

0 comments on commit 4d8049b

Please sign in to comment.