Skip to content

Commit 4d8c7c6

Browse files
chenghao-inteldavies
authored andcommitted
[SPARK-10865] [SPARK-10866] [SQL] Fix bug of ceil/floor, which should returns long instead of the Double type
Floor & Ceiling function should returns Long type, rather than Double. Verified with MySQL & Hive. Author: Cheng Hao <hao.cheng@intel.com> Closes #8933 from chenghao-intel/ceiling.
1 parent 9b3e776 commit 4d8c7c6

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ abstract class LeafMathExpression(c: Double, name: String)
5252
* @param f The math function.
5353
* @param name The short name of the function
5454
*/
55-
abstract class UnaryMathExpression(f: Double => Double, name: String)
55+
abstract class UnaryMathExpression(val f: Double => Double, name: String)
5656
extends UnaryExpression with Serializable with ImplicitCastInputTypes {
5757

5858
override def inputTypes: Seq[DataType] = Seq(DoubleType)
@@ -152,7 +152,16 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
152152

153153
case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
154154

155-
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL")
155+
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
156+
override def dataType: DataType = LongType
157+
protected override def nullSafeEval(input: Any): Any = {
158+
f(input.asInstanceOf[Double]).toLong
159+
}
160+
161+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
162+
defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
163+
}
164+
}
156165

157166
case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
158167

@@ -195,7 +204,16 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
195204

196205
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
197206

198-
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
207+
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
208+
override def dataType: DataType = LongType
209+
protected override def nullSafeEval(input: Any): Any = {
210+
f(input.asInstanceOf[Double]).toLong
211+
}
212+
213+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
214+
defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
215+
}
216+
}
199217

200218
object Factorial {
201219

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
244244
}
245245

246246
test("ceil") {
247-
testUnary(Ceil, math.ceil)
247+
testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
248248
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
249249
}
250250

251251
test("floor") {
252-
testUnary(Floor, math.floor)
252+
testUnary(Floor, (d: Double) => math.floor(d).toLong)
253253
checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
254254
}
255255

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
3737
private lazy val nullDoubles =
3838
Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF()
3939

40-
private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T](
40+
private def testOneToOneMathFunction[
41+
@specialized(Int, Long, Float, Double) T,
42+
@specialized(Int, Long, Float, Double) U](
4143
c: Column => Column,
42-
f: T => T): Unit = {
44+
f: T => U): Unit = {
4345
checkAnswer(
4446
doubleData.select(c('a)),
4547
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T])))
@@ -165,10 +167,10 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
165167
}
166168

167169
test("ceil and ceiling") {
168-
testOneToOneMathFunction(ceil, math.ceil)
170+
testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong)
169171
checkAnswer(
170172
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
171-
Row(0.0, 1.0, 2.0))
173+
Row(0L, 1L, 2L))
172174
}
173175

174176
test("conv") {
@@ -184,7 +186,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
184186
}
185187

186188
test("floor") {
187-
testOneToOneMathFunction(floor, math.floor)
189+
testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong)
188190
}
189191

190192
test("factorial") {
@@ -228,7 +230,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
228230
}
229231

230232
test("signum / sign") {
231-
testOneToOneMathFunction[Double](signum, math.signum)
233+
testOneToOneMathFunction[Double, Double](signum, math.signum)
232234

233235
checkAnswer(
234236
sql("SELECT sign(10), signum(-11)"),

0 commit comments

Comments
 (0)