Skip to content

Commit 630ebc5

Browse files
committed
Specify an ordering for NaN values.
1 parent 9bf195a commit 630ebc5

File tree

6 files changed

+77
-4
lines changed

6 files changed

+77
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ class CodeGenContext {
182182
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
183183
// java boolean doesn't support > or < operator
184184
case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
185+
case DoubleType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareDoubles($c1, $c2)"
186+
case FloatType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareFloats($c1, $c2)"
185187
// use c1 - c2 may overflow
186188
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
187189
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,24 @@ object TypeUtils {
7070
}
7171
x.length - y.length
7272
}
73+
74+
def compareDoubles(x: Double, y: Double): Int = {
75+
val xIsNan: Boolean = java.lang.Double.isNaN(x)
76+
val yIsNan: Boolean = java.lang.Double.isNaN(y)
77+
if ((xIsNan && yIsNan) || (x == y)) 0
78+
else if (xIsNan) -1
79+
else if (yIsNan) 1
80+
else if (x > y) -1
81+
else 1
82+
}
83+
84+
def compareFloats(x: Float, y: Float): Int = {
85+
val xIsNan: Boolean = java.lang.Float.isNaN(x)
86+
val yIsNan: Boolean = java.lang.Float.isNaN(y)
87+
if ((xIsNan && yIsNan) || (x == y)) 0
88+
else if (xIsNan) -1
89+
else if (yIsNan) 1
90+
else if (x > y) -1
91+
else 1
92+
}
7393
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
2323

2424
import org.apache.spark.annotation.DeveloperApi
2525
import org.apache.spark.sql.catalyst.ScalaReflectionLock
26+
import org.apache.spark.sql.catalyst.util.TypeUtils
2627

2728
/**
2829
* :: DeveloperApi ::
@@ -39,7 +40,9 @@ class DoubleType private() extends FractionalType {
3940
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
4041
private[sql] val numeric = implicitly[Numeric[Double]]
4142
private[sql] val fractional = implicitly[Fractional[Double]]
42-
private[sql] val ordering = implicitly[Ordering[InternalType]]
43+
private[sql] val ordering = new Ordering[Double] {
44+
override def compare(x: Double, y: Double): Int = TypeUtils.compareDoubles(x, y)
45+
}
4346
private[sql] val asIntegral = DoubleAsIfIntegral
4447

4548
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
2323

2424
import org.apache.spark.annotation.DeveloperApi
2525
import org.apache.spark.sql.catalyst.ScalaReflectionLock
26+
import org.apache.spark.sql.catalyst.util.TypeUtils
2627

2728
/**
2829
* :: DeveloperApi ::
@@ -39,7 +40,9 @@ class FloatType private() extends FractionalType {
3940
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
4041
private[sql] val numeric = implicitly[Numeric[Float]]
4142
private[sql] val fractional = implicitly[Fractional[Float]]
42-
private[sql] val ordering = implicitly[Ordering[InternalType]]
43+
private[sql] val ordering = new Ordering[Float] {
44+
override def compare(x: Float, y: Float): Int = TypeUtils.compareFloats(x, y)
45+
}
4346
private[sql] val asIntegral = FloatAsIfIntegral
4447

4548
/**
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import org.apache.spark.SparkFunSuite
21+
22+
class TypeUtilsSuite extends SparkFunSuite {
23+
24+
import TypeUtils._
25+
26+
test("compareDoubles") {
27+
assert(compareDoubles(0, 0) === 0)
28+
assert(compareDoubles(1, 0) === -1)
29+
assert(compareDoubles(0, 1) === 1)
30+
assert(compareDoubles(Double.MinValue, Double.MaxValue) === 1)
31+
assert(compareDoubles(Double.NaN, Double.NaN) === 0)
32+
assert(compareDoubles(Double.NaN, Double.PositiveInfinity) === -1)
33+
assert(compareDoubles(Double.NaN, Double.NegativeInfinity) === -1)
34+
}
35+
36+
test("compareFloats") {
37+
assert(compareFloats(0, 0) === 0)
38+
assert(compareFloats(1, 0) === -1)
39+
assert(compareFloats(0, 1) === 1)
40+
assert(compareFloats(Float.MinValue, Float.MaxValue) === 1)
41+
assert(compareFloats(Float.NaN, Float.NaN) === 0)
42+
assert(compareFloats(Float.NaN, Float.PositiveInfinity) === -1)
43+
assert(compareFloats(Float.NaN, Float.NegativeInfinity) === -1)
44+
}
45+
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,13 +739,13 @@ class DataFrameSuite extends QueryTest {
739739
df.col("t.``")
740740
}
741741

742-
test("SPARK-XXXX: sort by float column containing NaN") {
742+
test("SPARK-8797: sort by float column containing NaN") {
743743
val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
744744
val df = Random.shuffle(inputData).toDF("a")
745745
df.orderBy("a").collect()
746746
}
747747

748-
test("SPARK-XXXX: sort by double column containing NaN") {
748+
test("SPARK-8797: sort by double column containing NaN") {
749749
val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
750750
val df = Random.shuffle(inputData).toDF("a")
751751
df.orderBy("a").collect()

0 commit comments

Comments
 (0)