Skip to content

Commit c032b0b

Browse files
committed
[SPARK-8797] [SPARK-9146] [SPARK-9145] [SPARK-9147] Support NaN ordering and equality comparisons in Spark SQL
This patch addresses an issue where queries that sorted float or double columns containing NaN values could fail with "Comparison method violates its general contract!" errors from TimSort. The root of this problem is that `NaN > anything`, `NaN == anything`, and `NaN < anything` all return `false`. Per the design specified in SPARK-9079, we have decided that `NaN = NaN` should return true and that NaN should appear last when sorting in ascending order (i.e. it is larger than any other numeric value). In addition to implementing these semantics, this patch also adds canonicalization of NaN values in UnsafeRow, which is necessary in order to be able to do binary equality comparisons on equal NaNs that might have different bit representations (see SPARK-9147). Author: Josh Rosen <joshrosen@databricks.com> Closes #7194 from JoshRosen/nan and squashes the following commits: 983d4fc [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan 88bd73c [Josh Rosen] Fix Row.equals() a702e2e [Josh Rosen] normalization -> canonicalization a7267cf [Josh Rosen] Normalize NaNs in UnsafeRow fe629ae [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan fbb2a29 [Josh Rosen] Fix NaN comparisons in BinaryComparison expressions c1fd4fe [Josh Rosen] Fold NaN test into existing test framework b31eb19 [Josh Rosen] Uncomment failing tests 7fe67af [Josh Rosen] Support NaN == NaN (SPARK-9145) 58bad2c [Josh Rosen] Revert "Compare rows' string representations to work around NaN incomparability." fc6b4d2 [Josh Rosen] Update CodeGenerator 3998ef2 [Josh Rosen] Remove unused code a2ba2e7 [Josh Rosen] Fix prefix comparision for NaNs a30d371 [Josh Rosen] Compare rows' string representations to work around NaN incomparability. 6f03f85 [Josh Rosen] Fix bug in Double / Float ordering 42a1ad5 [Josh Rosen] Stop filtering NaNs in UnsafeExternalSortSuite bfca524 [Josh Rosen] Change ordering so that NaN is maximum value. 8d7be61 [Josh Rosen] Update randomized test to use ScalaTest's assume() b20837b [Josh Rosen] Add failing test for new NaN comparision ordering 5b88b2b [Josh Rosen] Fix compilation of CodeGenerationSuite d907b5b [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan 630ebc5 [Josh Rosen] Specify an ordering for NaN values. 9bf195a [Josh Rosen] Re-enable NaNs in CodeGenerationSuite to produce more regression tests 13fc06a [Josh Rosen] Add regression test for NaN sorting issue f9efbb5 [Josh Rosen] Fix ORDER BY NULL e7dc4fb [Josh Rosen] Add very generic test for ordering 7d5c13e [Josh Rosen] Add regression test for SPARK-8782 (ORDER BY NULL) b55875a [Josh Rosen] Generate doubles and floats over entire possible range. 5acdd5c [Josh Rosen] Infinity and NaN are interesting. ab76cbd [Josh Rosen] Move code to Catalyst package. d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL.
1 parent 4d97be9 commit c032b0b

File tree

16 files changed

+243
-26
lines changed

16 files changed

+243
-26
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.apache.spark.annotation.Private;
2525
import org.apache.spark.unsafe.types.UTF8String;
26+
import org.apache.spark.util.Utils;
2627

2728
@Private
2829
public class PrefixComparators {
@@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator {
8283
public int compare(long aPrefix, long bPrefix) {
8384
float a = Float.intBitsToFloat((int) aPrefix);
8485
float b = Float.intBitsToFloat((int) bPrefix);
85-
return (a < b) ? -1 : (a > b) ? 1 : 0;
86+
return Utils.nanSafeCompareFloats(a, b);
8687
}
8788

8889
public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator {
9798
public int compare(long aPrefix, long bPrefix) {
9899
double a = Double.longBitsToDouble(aPrefix);
99100
double b = Double.longBitsToDouble(bPrefix);
100-
return (a < b) ? -1 : (a > b) ? 1 : 0;
101+
return Utils.nanSafeCompareDoubles(a, b);
101102
}
102103

103104
public long computePrefix(double value) {

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging {
15861586
hashAbs
15871587
}
15881588

1589+
/**
1590+
* NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
1591+
* according to semantics where NaN == NaN and NaN > any non-NaN double.
1592+
*/
1593+
def nanSafeCompareDoubles(x: Double, y: Double): Int = {
1594+
val xIsNan: Boolean = java.lang.Double.isNaN(x)
1595+
val yIsNan: Boolean = java.lang.Double.isNaN(y)
1596+
if ((xIsNan && yIsNan) || (x == y)) 0
1597+
else if (xIsNan) 1
1598+
else if (yIsNan) -1
1599+
else if (x > y) 1
1600+
else -1
1601+
}
1602+
1603+
/**
1604+
* NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
1605+
* according to semantics where NaN == NaN and NaN > any non-NaN float.
1606+
*/
1607+
def nanSafeCompareFloats(x: Float, y: Float): Int = {
1608+
val xIsNan: Boolean = java.lang.Float.isNaN(x)
1609+
val yIsNan: Boolean = java.lang.Float.isNaN(y)
1610+
if ((xIsNan && yIsNan) || (x == y)) 0
1611+
else if (xIsNan) 1
1612+
else if (yIsNan) -1
1613+
else if (x > y) 1
1614+
else -1
1615+
}
1616+
15891617
/** Returns the system properties map that is thread-safe to iterator over. It gets the
15901618
* properties which have been set explicitly, as well as those for which only a default value
15911619
* has been defined. */

core/src/test/scala/org/apache/spark/util/UtilsSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.util
1919

2020
import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
21+
import java.lang.{Double => JDouble, Float => JFloat}
2122
import java.net.{BindException, ServerSocket, URI}
2223
import java.nio.{ByteBuffer, ByteOrder}
2324
import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
689690
// scalastyle:on println
690691
assert(buffer.toString === "t circular test circular\n")
691692
}
693+
694+
test("nanSafeCompareDoubles") {
695+
def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
696+
assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
697+
assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
698+
}
699+
shouldMatchDefaultOrder(0d, 0d)
700+
shouldMatchDefaultOrder(0d, 1d)
701+
shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
702+
assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
703+
assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
704+
assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
705+
assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
706+
assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
707+
}
708+
709+
test("nanSafeCompareFloats") {
710+
def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
711+
assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
712+
assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
713+
}
714+
shouldMatchDefaultOrder(0f, 0f)
715+
shouldMatchDefaultOrder(1f, 1f)
716+
shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
717+
assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
718+
assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
719+
assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
720+
assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
721+
assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
722+
}
692723
}

core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
4747
forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
4848
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
4949
}
50+
51+
test("float prefix comparator handles NaN properly") {
52+
val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
53+
val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
54+
assert(nan1.isNaN)
55+
assert(nan2.isNaN)
56+
val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
57+
val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
58+
assert(nan1Prefix === nan2Prefix)
59+
val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
60+
assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
61+
}
62+
63+
test("double prefix comparator handles NaNs properly") {
64+
val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
65+
val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
66+
assert(nan1.isNaN)
67+
assert(nan2.isNaN)
68+
val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
69+
val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
70+
assert(nan1Prefix === nan2Prefix)
71+
val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
72+
assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
73+
}
74+
5075
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ public void setLong(int ordinal, long value) {
215215
public void setDouble(int ordinal, double value) {
216216
assertIndexIsValid(ordinal);
217217
setNotNullAt(ordinal);
218+
if (Double.isNaN(value)) {
219+
value = Double.NaN;
220+
}
218221
PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
219222
}
220223

@@ -243,6 +246,9 @@ public void setByte(int ordinal, byte value) {
243246
public void setFloat(int ordinal, float value) {
244247
assertIndexIsValid(ordinal);
245248
setNotNullAt(ordinal);
249+
if (Float.isNaN(value)) {
250+
value = Float.NaN;
251+
}
246252
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
247253
}
248254

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,20 +403,28 @@ trait Row extends Serializable {
403403
if (!isNullAt(i)) {
404404
val o1 = get(i)
405405
val o2 = other.get(i)
406-
if (o1.isInstanceOf[Array[Byte]]) {
407-
// handle equality of Array[Byte]
408-
val b1 = o1.asInstanceOf[Array[Byte]]
409-
if (!o2.isInstanceOf[Array[Byte]] ||
410-
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
406+
o1 match {
407+
case b1: Array[Byte] =>
408+
if (!o2.isInstanceOf[Array[Byte]] ||
409+
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
410+
return false
411+
}
412+
case f1: Float if java.lang.Float.isNaN(f1) =>
413+
if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
414+
return false
415+
}
416+
case d1: Double if java.lang.Double.isNaN(d1) =>
417+
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
418+
return false
419+
}
420+
case _ => if (o1 != o2) {
411421
return false
412422
}
413-
} else if (o1 != o2) {
414-
return false
415423
}
416424
}
417425
i += 1
418426
}
419-
return true
427+
true
420428
}
421429

422430
/* ---------------------- utility methods for Scala ---------------------- */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ class CodeGenContext {
194194
*/
195195
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
196196
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
197+
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
198+
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
197199
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
198200
case other => s"$c1.equals($c2)"
199201
}
@@ -204,6 +206,8 @@ class CodeGenContext {
204206
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
205207
// java boolean doesn't support > or < operator
206208
case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
209+
case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
210+
case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
207211
// use c1 - c2 may overflow
208212
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
209213
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2424
import org.apache.spark.sql.types._
25+
import org.apache.spark.util.Utils
2526

2627

2728
object InterpretedPredicate {
@@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
222223
abstract class BinaryComparison extends BinaryOperator with Predicate {
223224

224225
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
225-
if (ctx.isPrimitiveType(left.dataType)) {
226+
if (ctx.isPrimitiveType(left.dataType)
227+
&& left.dataType != FloatType
228+
&& left.dataType != DoubleType) {
226229
// faster version
227230
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
228231
} else {
@@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
254257
override def symbol: String = "="
255258

256259
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
257-
if (left.dataType != BinaryType) input1 == input2
258-
else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
260+
if (left.dataType == FloatType) {
261+
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
262+
} else if (left.dataType == DoubleType) {
263+
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
264+
} else if (left.dataType != BinaryType) {
265+
input1 == input2
266+
} else {
267+
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
268+
}
259269
}
260270

261271
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
280290
} else if (input1 == null || input2 == null) {
281291
false
282292
} else {
283-
if (left.dataType != BinaryType) {
293+
if (left.dataType == FloatType) {
294+
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
295+
} else if (left.dataType == DoubleType) {
296+
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
297+
} else if (left.dataType != BinaryType) {
284298
input1 == input2
285299
} else {
286300
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
2323

2424
import org.apache.spark.annotation.DeveloperApi
2525
import org.apache.spark.sql.catalyst.ScalaReflectionLock
26+
import org.apache.spark.util.Utils
2627

2728
/**
2829
* :: DeveloperApi ::
@@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType {
3738
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
3839
private[sql] val numeric = implicitly[Numeric[Double]]
3940
private[sql] val fractional = implicitly[Fractional[Double]]
40-
private[sql] val ordering = implicitly[Ordering[InternalType]]
41+
private[sql] val ordering = new Ordering[Double] {
42+
override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
43+
}
4144
private[sql] val asIntegral = DoubleAsIfIntegral
4245

4346
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
2323

2424
import org.apache.spark.annotation.DeveloperApi
2525
import org.apache.spark.sql.catalyst.ScalaReflectionLock
26+
import org.apache.spark.util.Utils
2627

2728
/**
2829
* :: DeveloperApi ::
@@ -37,7 +38,9 @@ class FloatType private() extends FractionalType {
3738
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
3839
private[sql] val numeric = implicitly[Numeric[Float]]
3940
private[sql] val fractional = implicitly[Fractional[Float]]
40-
private[sql] val ordering = implicitly[Ordering[InternalType]]
41+
private[sql] val ordering = new Ordering[Float] {
42+
override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
43+
}
4144
private[sql] val asIntegral = FloatAsIfIntegral
4245

4346
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import scala.math._
21+
2022
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.sql.RandomDataGenerator
24+
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2125
import org.apache.spark.sql.catalyst.dsl.expressions._
2226
import org.apache.spark.sql.catalyst.expressions.codegen._
27+
import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
2328

2429
/**
2530
* Additional tests for code generation.
@@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
4348
futures.foreach(Await.result(_, 10.seconds))
4449
}
4550

51+
// Test GenerateOrdering for all common types. For each type, we construct random input rows that
52+
// contain two columns of that type, then for pairs of randomly-generated rows we check that
53+
// GenerateOrdering agrees with RowOrdering.
54+
(DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
55+
test(s"GenerateOrdering with $dataType") {
56+
val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
57+
val genOrdering = GenerateOrdering.generate(
58+
BoundReference(0, dataType, nullable = true).asc ::
59+
BoundReference(1, dataType, nullable = true).asc :: Nil)
60+
val rowType = StructType(
61+
StructField("a", dataType, nullable = true) ::
62+
StructField("b", dataType, nullable = true) :: Nil)
63+
val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
64+
assume(maybeDataGenerator.isDefined)
65+
val randGenerator = maybeDataGenerator.get
66+
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
67+
for (_ <- 1 to 50) {
68+
val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
69+
val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
70+
withClue(s"a = $a, b = $b") {
71+
assert(genOrdering.compare(a, a) === 0)
72+
assert(genOrdering.compare(b, b) === 0)
73+
assert(rowOrdering.compare(a, a) === 0)
74+
assert(rowOrdering.compare(b, b) === 0)
75+
assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
76+
assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
77+
assert(
78+
signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
79+
"Generated and non-generated orderings should agree")
80+
}
81+
}
82+
}
83+
}
84+
4685
test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
4786
val length = 5000
4887
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
136136
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
137137
}
138138

139-
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
140-
private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
141-
142-
private val equalValues1 = smallValues
143-
private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
139+
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
140+
private val largeValues =
141+
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
142+
143+
private val equalValues1 =
144+
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
145+
private val equalValues2 =
146+
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
144147

145148
test("BinaryComparison: <") {
146149
for (i <- 0 until smallValues.length) {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
316316
assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
317317
}
318318

319+
test("NaN canonicalization") {
320+
val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
321+
322+
val row1 = new SpecificMutableRow(fieldTypes)
323+
row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001))
324+
row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L))
325+
326+
val row2 = new SpecificMutableRow(fieldTypes)
327+
row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
328+
row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
329+
330+
val converter = new UnsafeRowConverter(fieldTypes)
331+
val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
332+
val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
333+
converter.writeRow(
334+
row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null)
335+
converter.writeRow(
336+
row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null)
337+
338+
assert(row1Buffer.toSeq === row2Buffer.toSeq)
339+
}
340+
319341
}

0 commit comments

Comments
 (0)