Skip to content

Commit 20cc8b1

Browse files
committed
Simplify code
1 parent 43fe7f4 commit 20cc8b1

File tree

3 files changed: +14 additions, −12 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistr
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.trees.UnaryLike
24+
import org.apache.spark.sql.catalyst.util.TypeUtils
2425
import org.apache.spark.sql.types._
2526

2627
@ExpressionDescription(
@@ -42,12 +43,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
4243
override def inputTypes: Seq[AbstractDataType] =
4344
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
4445

45-
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
46-
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
47-
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
48-
case other => TypeCheckResult.TypeCheckFailure(
49-
s"function average requires numeric or interval types, not ${other.catalogString}")
50-
}
46+
override def checkInputDataTypes(): TypeCheckResult =
47+
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
5148

5249
override def nullable: Boolean = true
5350

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.trees.UnaryLike
24+
import org.apache.spark.sql.catalyst.util.TypeUtils
2425
import org.apache.spark.sql.internal.SQLConf
2526
import org.apache.spark.sql.types._
2627

@@ -48,12 +49,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
4849
override def inputTypes: Seq[AbstractDataType] =
4950
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
5051

51-
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
52-
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
53-
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
54-
case other => TypeCheckResult.TypeCheckFailure(
55-
s"function sum requires numeric or interval types, not ${other.catalogString}")
56-
}
52+
override def checkInputDataTypes(): TypeCheckResult =
53+
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
5754

5855
private lazy val resultType = child.dataType match {
5956
case DecimalType.Fixed(precision, scale) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ object TypeUtils {
6161
}
6262
}
6363

64+
def checkForAnsiIntervalOrNumericType(
65+
dt: DataType, funcName: String): TypeCheckResult = dt match {
66+
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
67+
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
68+
case other => TypeCheckResult.TypeCheckFailure(
69+
s"function $funcName requires numeric or interval types, not ${other.catalogString}")
70+
}
71+
6472
def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
6573
if (exactNumericRequired) {
6674
t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]

0 commit comments

Comments (0)