Skip to content

Commit 8904791

Browse files
committed
[SPARK-2659][SQL] Fix division semantics for hive
Author: Michael Armbrust <michael@databricks.com> Closes #1557 from marmbrus/fixDivision and squashes the following commits: b85077f [Michael Armbrust] Fix unit tests. af98f29 [Michael Armbrust] Change DIV to long type 0c29ae8 [Michael Armbrust] Fix division semantics for hive
1 parent 9d8666c commit 8904791

File tree

7 files changed

+27
-9
lines changed

7 files changed

+27
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ trait HiveTypeCoercion {
5050
StringToIntegralCasts ::
5151
FunctionArgumentConversion ::
5252
CastNulls ::
53+
Division ::
5354
Nil
5455

5556
/**
@@ -317,6 +318,23 @@ trait HiveTypeCoercion {
317318
}
318319
}
319320

321+
/**
322+
* Hive only performs integral division with the DIV operator. The arguments to / are always
323+
* converted to fractional types.
324+
*/
325+
object Division extends Rule[LogicalPlan] {
326+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
327+
// Skip nodes who's children have not been resolved yet.
328+
case e if !e.childrenResolved => e
329+
330+
// Decimal and Double remain the same
331+
case d: Divide if d.dataType == DoubleType => d
332+
case d: Divide if d.dataType == DecimalType => d
333+
334+
case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
335+
}
336+
}
337+
320338
/**
321339
* Ensures that NullType gets casted to some other types under certain circumstances.
322340
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest {
8383
Literal(10) as Symbol("2*3+4"),
8484
Literal(14) as Symbol("2*(3+4)"))
8585
.where(Literal(true))
86-
.groupBy(Literal(3))(Literal(3) as Symbol("9/3"))
86+
.groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3"))
8787
.analyze
8888

8989
comparePlans(optimized, correctAnswer)

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,8 @@ private[hive] object HiveQl {
925925
case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
926926
case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
927927
case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
928-
case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
928+
case Token(DIV(), left :: right:: Nil) =>
929+
Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
929930
case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
930931

931932
/* Comparisons */
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0 0 0 1 2
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
2.0 0.5 0.3333333333333333 0.002

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,12 +350,6 @@ abstract class HiveComparisonTest
350350

351351
val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n")
352352

353-
println("hive output")
354-
hive.foreach(println)
355-
356-
println("catalyst printout")
357-
catalyst.foreach(println)
358-
359353
if (recomputeCache) {
360354
logger.warn(s"Clearing cache files for failed test $testCaseName")
361355
hiveCacheFiles.foreach(_.delete())

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest {
5252
"SELECT * FROM src WHERE key Between 1 and 2")
5353

5454
createQueryTest("div",
55-
"SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1")
55+
"SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1")
56+
57+
createQueryTest("division",
58+
"SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")
5659

5760
test("Query expressed in SQL") {
5861
assert(sql("SELECT 1").collect() === Array(Seq(1)))

0 commit comments

Comments
 (0)