Skip to content

Commit 88bd73c

Browse files
committed
Fix Row.equals()
1 parent a702e2e commit 88bd73c

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -406,17 +406,15 @@ trait Row extends Serializable {
406406
o1 match {
407407
case b1: Array[Byte] =>
408408
if (!o2.isInstanceOf[Array[Byte]] ||
409-
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
409+
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
410410
return false
411411
}
412-
case f1: Float =>
413-
if (!o2.isInstanceOf[Float] ||
414-
(java.lang.Float.isNaN(f1) && !java.lang.Float.isNaN(o2.asInstanceOf[Float]))) {
415-
return false
412+
case f1: Float if java.lang.Float.isNaN(f1) =>
413+
if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
414+
return false
416415
}
417-
case d1: Double =>
418-
if (!o2.isInstanceOf[Double] ||
419-
(java.lang.Double.isNaN(d1) && !java.lang.Double.isNaN(o2.asInstanceOf[Double]))) {
416+
case d1: Double if java.lang.Double.isNaN(d1) =>
417+
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
420418
return false
421419
}
422420
case _ => if (o1 != o2) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2424
import org.apache.spark.sql.types._
25+
import org.apache.spark.util.Utils
2526

2627

2728
object InterpretedPredicate {
@@ -257,13 +258,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
257258

258259
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
259260
if (left.dataType == FloatType) {
260-
val f1 = input1.asInstanceOf[Float]
261-
val f2 = input2.asInstanceOf[Float]
262-
(java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2
261+
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
263262
} else if (left.dataType == DoubleType) {
264-
val d1 = input1.asInstanceOf[Double]
265-
val d2 = input2.asInstanceOf[Double]
266-
(java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2
263+
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
267264
} else if (left.dataType != BinaryType) {
268265
input1 == input2
269266
} else {
@@ -294,13 +291,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
294291
false
295292
} else {
296293
if (left.dataType == FloatType) {
297-
val f1 = input1.asInstanceOf[Float]
298-
val f2 = input2.asInstanceOf[Float]
299-
(java.lang.Float.isNaN(f1) && java.lang.Float.isNaN(f2)) || f1 == f2
294+
Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
300295
} else if (left.dataType == DoubleType) {
301-
val d1 = input1.asInstanceOf[Double]
302-
val d2 = input2.asInstanceOf[Double]
303-
(java.lang.Double.isNaN(d1) && java.lang.Double.isNaN(d2)) || d1 == d2
296+
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
304297
} else if (left.dataType != BinaryType) {
305298
input1 == input2
306299
} else {

0 commit comments

Comments
 (0)