Skip to content

Commit 6fc66e8

Browse files
committed
Fix:If the input parameter is float type for ceil or floor,the result is not
we expected
1 parent 287440d commit 6fc66e8

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
232232
}
233233

234234
override def inputTypes: Seq[AbstractDataType] =
235-
Seq(TypeCollection(DoubleType, DecimalType))
235+
Seq(TypeCollection(DoubleType, DecimalType, LongType))
236236

237237
protected override def nullSafeEval(input: Any): Any = child.dataType match {
238+
case LongType => input.asInstanceOf[Long]
238239
case DoubleType => f(input.asInstanceOf[Double]).toLong
239-
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
240+
case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
240241
}
241242

242243
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
243244
child.dataType match {
244245
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
245-
case DecimalType.Fixed(precision, scale) =>
246+
case DecimalType.Fixed(_, _) =>
246247
defineCodeGen(ctx, ev, c => s"$c.ceil()")
248+
case LongType => defineCodeGen(ctx, ev, c => s"$c")
247249
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
248250
}
249251
}
@@ -347,18 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
347349
}
348350

349351
override def inputTypes: Seq[AbstractDataType] =
350-
Seq(TypeCollection(DoubleType, DecimalType))
352+
Seq(TypeCollection(DoubleType, DecimalType, LongType))
351353

352354
protected override def nullSafeEval(input: Any): Any = child.dataType match {
355+
case LongType => input.asInstanceOf[Long]
353356
case DoubleType => f(input.asInstanceOf[Double]).toLong
354-
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
357+
case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
355358
}
356359

357360
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
358361
child.dataType match {
359362
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
360-
case DecimalType.Fixed(precision, scale) =>
363+
case DecimalType.Fixed(_, _) =>
361364
defineCodeGen(ctx, ev, c => s"$c.floor()")
365+
case LongType => defineCodeGen(ctx, ev, c => s"$c")
362366
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
363367
}
364368
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
252252
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
253253
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
254254
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
255+
256+
val doublePi: Double = 3.1415
257+
val floatPi: Float = 3.1415f
258+
val longLit: Long = 12345678901234567L
259+
checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
260+
checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
261+
checkEvaluation(Ceil(longLit), longLit, EmptyRow)
262+
checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
263+
checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
264+
checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
255265
}
256266

257267
test("floor") {
@@ -262,6 +272,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
262272
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
263273
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
264274
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
275+
276+
val doublePi: Double = 3.1415
277+
val floatPi: Float = 3.1415f
278+
val longLit: Long = 12345678901234567L
279+
checkEvaluation(Floor(doublePi), 3L, EmptyRow)
280+
checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
281+
checkEvaluation(Floor(longLit), longLit, EmptyRow)
282+
checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
283+
checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
284+
checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
265285
}
266286

267287
test("factorial") {

0 commit comments

Comments
 (0)