@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
22
22
import org .apache .spark .sql .catalyst .expressions .codegen ._
23
23
import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
24
24
import org .apache .spark .sql .types ._
25
+ import org .apache .spark .util .Utils
25
26
26
27
27
28
object InterpretedPredicate {
@@ -257,13 +258,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
257
258
258
259
protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
259
260
if (left.dataType == FloatType ) {
260
- val f1 = input1.asInstanceOf [Float ]
261
- val f2 = input2.asInstanceOf [Float ]
262
- (java.lang.Float .isNaN(f1) && java.lang.Float .isNaN(f2)) || f1 == f2
261
+ Utils .nanSafeCompareFloats(input1.asInstanceOf [Float ], input2.asInstanceOf [Float ]) == 0
263
262
} else if (left.dataType == DoubleType ) {
264
- val d1 = input1.asInstanceOf [Double ]
265
- val d2 = input2.asInstanceOf [Double ]
266
- (java.lang.Double .isNaN(d1) && java.lang.Double .isNaN(d2)) || d1 == d2
263
+ Utils .nanSafeCompareDoubles(input1.asInstanceOf [Double ], input2.asInstanceOf [Double ]) == 0
267
264
} else if (left.dataType != BinaryType ) {
268
265
input1 == input2
269
266
} else {
@@ -294,13 +291,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
294
291
false
295
292
} else {
296
293
if (left.dataType == FloatType ) {
297
- val f1 = input1.asInstanceOf [Float ]
298
- val f2 = input2.asInstanceOf [Float ]
299
- (java.lang.Float .isNaN(f1) && java.lang.Float .isNaN(f2)) || f1 == f2
294
+ Utils .nanSafeCompareFloats(input1.asInstanceOf [Float ], input2.asInstanceOf [Float ]) == 0
300
295
} else if (left.dataType == DoubleType ) {
301
- val d1 = input1.asInstanceOf [Double ]
302
- val d2 = input2.asInstanceOf [Double ]
303
- (java.lang.Double .isNaN(d1) && java.lang.Double .isNaN(d2)) || d1 == d2
296
+ Utils .nanSafeCompareDoubles(input1.asInstanceOf [Double ], input2.asInstanceOf [Double ]) == 0
304
297
} else if (left.dataType != BinaryType ) {
305
298
input1 == input2
306
299
} else {
0 commit comments