Skip to content

Commit fbb2a29

Browse files
committed
Fix NaN comparisons in BinaryComparison expressions
1 parent c1fd4fe commit fbb2a29

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
272272
abstract class BinaryComparison extends BinaryOperator with Predicate {
273273

274274
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
275-
if (ctx.isPrimitiveType(left.dataType)) {
275+
if (ctx.isPrimitiveType(left.dataType)
276+
&& left.dataType != FloatType
277+
&& left.dataType != DoubleType) {
276278
// faster version
277279
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
278280
} else {
@@ -304,10 +306,19 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
304306
override def symbol: String = "="
305307

306308
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
307-
// Note that we do not have to do anything special here to handle NaN values: boxed Double and
308-
// Float NaNs will be equal (see Float.equals()' Javadoc for more details).
309-
if (left.dataType != BinaryType) input1 == input2
310-
else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
309+
if (left.dataType == FloatType) {
310+
val f1 = input1.asInstanceOf[Float]
311+
val f2 = input2.asInstanceOf[Float]
312+
(java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2
313+
} else if (left.dataType == DoubleType) {
314+
val d1 = input1.asInstanceOf[Double]
315+
val d2 = input2.asInstanceOf[Double]
316+
(java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2
317+
} else if (left.dataType != BinaryType) {
318+
input1 == input2
319+
} else {
320+
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
321+
}
311322
}
312323

313324
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -332,9 +343,15 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
332343
} else if (input1 == null || input2 == null) {
333344
false
334345
} else {
335-
// Note that we do not have to do anything special here to handle NaN values: boxed Double and
336-
// Float NaNs will be equal (see Float.equals()' Javadoc for more details).
337-
if (left.dataType != BinaryType) {
346+
if (left.dataType == FloatType) {
347+
val f1 = input1.asInstanceOf[Float]
348+
val f2 = input2.asInstanceOf[Float]
349+
(java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2
350+
} else if (left.dataType == DoubleType) {
351+
val d1 = input1.asInstanceOf[Double]
352+
val d2 = input2.asInstanceOf[Double]
353+
(java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2
354+
} else if (left.dataType != BinaryType) {
338355
input1 == input2
339356
} else {
340357
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
146146
private val largeValues =
147147
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
148148

149-
private val equalValues1 = smallValues ++ Seq(Float.NaN, Double.NaN).map(Literal(_))
149+
private val equalValues1 =
150+
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
150151
private val equalValues2 =
151152
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
152153

0 commit comments

Comments
 (0)