diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetime.scala index d7f5f5b8b8e38..46d1c37ffcdfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetime.scala @@ -20,17 +20,27 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Date import java.text.SimpleDateFormat +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -case class DateFormatClass(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { +case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression { override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(TimestampType, StringType) + override def checkInputDataTypes(): TypeCheckResult = + (left.dataType, right.dataType) match { + case (null, _) => TypeCheckResult.TypeCheckSuccess + case (_, null) => TypeCheckResult.TypeCheckSuccess + case (_: DateType, _: StringType) => TypeCheckResult.TypeCheckSuccess + case (_: TimestampType, _: StringType) => TypeCheckResult.TypeCheckSuccess + case (_: StringType, _: StringType) => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"DateFormat accepts date types as first argument, " + + s"and string types as second, not ${left.dataType} and ${right.dataType}") + } override def foldable: Boolean = left.foldable && right.foldable @@ -71,7 +81,7 @@ case class DateFormatClass(left: Expression, right: Expression) val calc = left.dataType match { case TimestampType => - 
s"""$utf8.fromString(sdf.format(new java.sql.Date(${eval1.primitive} / 10000)));""" + s"""$utf8.fromString(sdf.format(new java.sql.Date(${eval1.primitive} / 10000)));""" case DateType => s"""$utf8.fromString( sdf.format($dtUtils.toJavaDate(${eval1.primitive})));""" @@ -115,6 +125,17 @@ case class Year(child: Expression) extends UnaryExpression with ExpectsInputType } } + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"Year accepts date types as argument, " + + s" not ${child.dataType}") + } + } case class Month(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -133,6 +154,17 @@ case class Month(child: Expression) extends UnaryExpression with ExpectsInputTyp case x: UTF8String => x.toString.toInt } } + + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"Month accepts date types as argument, " + + s" not ${child.dataType}") + } } case class Day(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -152,14 +184,23 @@ case class Day(child: Expression) extends UnaryExpression with ExpectsInputTypes } } + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => 
TypeCheckResult.TypeCheckFailure(s"Day accepts date types as argument, " + + s" not ${child.dataType}") + } + } -case class Hour(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Hour(child: Expression) extends UnaryExpression { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType) - override def foldable: Boolean = child.foldable override def nullable: Boolean = true @@ -170,14 +211,23 @@ case class Hour(child: Expression) extends UnaryExpression with ExpectsInputType case x: UTF8String => x.toString.toInt } } + + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"Hour accepts date types as argument, " + + s" not ${child.dataType}") + } } -case class Minute(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Minute(child: Expression) extends UnaryExpression { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType) - override def foldable: Boolean = child.foldable override def nullable: Boolean = true @@ -188,14 +238,23 @@ case class Minute(child: Expression) extends UnaryExpression with ExpectsInputTy case x: UTF8String => x.toString.toInt } } + + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"Minute accepts date types as argument, " + + s" not 
${child.dataType}") + } } -case class Second(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Second(child: Expression) extends UnaryExpression { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType) - override def foldable: Boolean = child.foldable override def nullable: Boolean = true @@ -206,14 +265,23 @@ case class Second(child: Expression) extends UnaryExpression with ExpectsInputTy case x: UTF8String => x.toString.toInt } } + + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"Second accepts date types as argument, " + + s" not ${child.dataType}") + } } -case class WeekOfYear(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class WeekOfYear(child: Expression) extends UnaryExpression { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(DateType, StringType, TimestampType) - override def foldable: Boolean = child.foldable override def nullable: Boolean = true @@ -225,5 +293,14 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ExpectsInp } } - + override def checkInputDataTypes(): TypeCheckResult = + child.dataType match { + case null => TypeCheckResult.TypeCheckSuccess + case _: DateType => TypeCheckResult.TypeCheckSuccess + case _: TimestampType => TypeCheckResult.TypeCheckSuccess + case _: StringType => TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"WeekOfYear accepts date types as argument, " + + s" not ${child.dataType}") + } } \ No newline at end of file diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e0f4c648e55be..b4bb4987177a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -179,7 +179,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row("2015", "2015", "2013")) checkAnswer( - df.selectExpr("dateFormat(a, y)", "dateFormat(b, y)", "dateFormat(c, y)"), + df.selectExpr("dateFormat(a, 'y')", "dateFormat(b, 'y')", "dateFormat(c, 'y')"), Row("2015", "2015", "2013")) } @@ -271,11 +271,11 @@ class DataFrameFunctionsSuite extends QueryTest { val df = Seq((d, d.toString, ts)).toDF("a", "b", "c") checkAnswer( - df.select(hour("a"), hour("b"), hour("c")), + df.select(second("a"), second("b"), second("c")), Row(0, 0, 15)) checkAnswer( - df.selectExpr("hour(a)", "hour(b)", "hour(c)"), + df.selectExpr("second(a)", "second(b)", "second(c)"), Row(0, 0, 15)) }