Skip to content

Commit 3969a80

Browse files
10110346 authored and gatorsmile committed
[SPARK-20876][SQL] If the input parameter is float type for ceil or floor, the result is not what we expected
## What changes were proposed in this pull request? spark-sql>SELECT ceil(cast(12345.1233 as float)); spark-sql>12345 For this case, the result we expected is `12346`. spark-sql>SELECT floor(cast(-12345.1233 as float)); spark-sql>-12345 For this case, the result we expected is `-12346`. This happens because `inputTypes` in `Ceil` and `Floor` does not include FloatType, so the float input is implicitly cast to LongType, the first type listed in the `TypeCollection`; the fractional part is therefore truncated before `ceil` or `floor` is applied. This patch lists DoubleType first so that the float input is cast to DoubleType instead. ## How was this patch tested? After the modification: spark-sql>SELECT ceil(cast(12345.1233 as float)); spark-sql>12346 spark-sql>SELECT floor(cast(-12345.1233 as float)); spark-sql>-12346 Author: liuxian <liu.xian3@zte.com.cn> Closes #18103 from 10110346/wip-lx-0525-1.
1 parent 08ede46 commit 3969a80

File tree

5 files changed

+43
-48
lines changed

5 files changed

+43
-48
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
232232
}
233233

234234
override def inputTypes: Seq[AbstractDataType] =
235-
Seq(TypeCollection(LongType, DoubleType, DecimalType))
235+
Seq(TypeCollection(DoubleType, DecimalType, LongType))
236236

237237
protected override def nullSafeEval(input: Any): Any = child.dataType match {
238238
case LongType => input.asInstanceOf[Long]
239239
case DoubleType => f(input.asInstanceOf[Double]).toLong
240-
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
240+
case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
241241
}
242242

243243
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
244244
child.dataType match {
245245
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
246-
case DecimalType.Fixed(precision, scale) =>
246+
case DecimalType.Fixed(_, _) =>
247247
defineCodeGen(ctx, ev, c => s"$c.ceil()")
248+
case LongType => defineCodeGen(ctx, ev, c => s"$c")
248249
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
249250
}
250251
}
@@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
348349
}
349350

350351
override def inputTypes: Seq[AbstractDataType] =
351-
Seq(TypeCollection(LongType, DoubleType, DecimalType))
352+
Seq(TypeCollection(DoubleType, DecimalType, LongType))
352353

353354
protected override def nullSafeEval(input: Any): Any = child.dataType match {
354355
case LongType => input.asInstanceOf[Long]
355356
case DoubleType => f(input.asInstanceOf[Double]).toLong
356-
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
357+
case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
357358
}
358359

359360
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
360361
child.dataType match {
361362
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
362-
case DecimalType.Fixed(precision, scale) =>
363+
case DecimalType.Fixed(_, _) =>
363364
defineCodeGen(ctx, ev, c => s"$c.floor()")
365+
case LongType => defineCodeGen(ctx, ev, c => s"$c")
364366
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
365367
}
366368
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
262262

263263
val plan = testRelation2.select('c).orderBy(Floor('a).asc)
264264
val expected = testRelation2.select(c, a)
265-
.orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c)
265+
.orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c)
266266

267267
checkAnalysis(plan, expected)
268268
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
258258
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
259259
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
260260
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
261+
262+
val doublePi: Double = 3.1415
263+
val floatPi: Float = 3.1415f
264+
val longLit: Long = 12345678901234567L
265+
checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
266+
checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
267+
checkEvaluation(Ceil(longLit), longLit, EmptyRow)
268+
checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
269+
checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
270+
checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
261271
}
262272

263273
test("floor") {
@@ -268,6 +278,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
268278
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
269279
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
270280
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
281+
282+
val doublePi: Double = 3.1415
283+
val floatPi: Float = 3.1415f
284+
val longLit: Long = 12345678901234567L
285+
checkEvaluation(Floor(doublePi), 3L, EmptyRow)
286+
checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
287+
checkEvaluation(Floor(longLit), longLit, EmptyRow)
288+
checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
289+
checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
290+
checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
271291
}
272292

273293
test("factorial") {

sql/core/src/test/resources/sql-tests/inputs/operators.sql

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,9 @@ select cot(-1);
6464
select ceiling(0);
6565
select ceiling(1);
6666
select ceil(1234567890123456);
67-
select ceil(12345678901234567);
6867
select ceiling(1234567890123456);
69-
select ceiling(12345678901234567);
7068

7169
-- floor
7270
select floor(0);
7371
select floor(1);
7472
select floor(1234567890123456);
75-
select floor(12345678901234567);

sql/core/src/test/resources/sql-tests/results/operators.sql.out

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 38
2+
-- Number of queries: 45
33

44

55
-- !query 0
@@ -321,15 +321,15 @@ struct<COT(CAST(-1 AS DOUBLE)):double>
321321
-- !query 38
322322
select ceiling(0)
323323
-- !query 38 schema
324-
struct<CEIL(CAST(0 AS BIGINT)):bigint>
324+
struct<CEIL(CAST(0 AS DOUBLE)):bigint>
325325
-- !query 38 output
326326
0
327327

328328

329329
-- !query 39
330330
select ceiling(1)
331331
-- !query 39 schema
332-
struct<CEIL(CAST(1 AS BIGINT)):bigint>
332+
struct<CEIL(CAST(1 AS DOUBLE)):bigint>
333333
-- !query 39 output
334334
1
335335

@@ -343,56 +343,32 @@ struct<CEIL(1234567890123456):bigint>
343343

344344

345345
-- !query 41
346-
select ceil(12345678901234567)
346+
select ceiling(1234567890123456)
347347
-- !query 41 schema
348-
struct<CEIL(12345678901234567):bigint>
348+
struct<CEIL(1234567890123456):bigint>
349349
-- !query 41 output
350-
12345678901234567
350+
1234567890123456
351351

352352

353353
-- !query 42
354-
select ceiling(1234567890123456)
354+
select floor(0)
355355
-- !query 42 schema
356-
struct<CEIL(1234567890123456):bigint>
356+
struct<FLOOR(CAST(0 AS DOUBLE)):bigint>
357357
-- !query 42 output
358-
1234567890123456
358+
0
359359

360360

361361
-- !query 43
362-
select ceiling(12345678901234567)
362+
select floor(1)
363363
-- !query 43 schema
364-
struct<CEIL(12345678901234567):bigint>
364+
struct<FLOOR(CAST(1 AS DOUBLE)):bigint>
365365
-- !query 43 output
366-
12345678901234567
367-
368-
369-
-- !query 44
370-
select floor(0)
371-
-- !query 44 schema
372-
struct<FLOOR(CAST(0 AS BIGINT)):bigint>
373-
-- !query 44 output
374-
0
375-
376-
377-
-- !query 45
378-
select floor(1)
379-
-- !query 45 schema
380-
struct<FLOOR(CAST(1 AS BIGINT)):bigint>
381-
-- !query 45 output
382366
1
383367

384368

385-
-- !query 46
369+
-- !query 44
386370
select floor(1234567890123456)
387-
-- !query 46 schema
371+
-- !query 44 schema
388372
struct<FLOOR(1234567890123456):bigint>
389-
-- !query 46 output
373+
-- !query 44 output
390374
1234567890123456
391-
392-
393-
-- !query 47
394-
select floor(12345678901234567)
395-
-- !query 47 schema
396-
struct<FLOOR(12345678901234567):bigint>
397-
-- !query 47 output
398-
12345678901234567

0 commit comments

Comments
 (0)