Skip to content

Commit f86fba6

Browse files
committed
Changes implementation
Clean solution
1 parent 8b94eff commit f86fba6

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,21 @@ abstract class HashExpression[E] extends Expression {
369369
/**
 * Builds the code snippet that hashes a boolean-typed `input` into `result`.
 *
 * The boolean expression is rewritten as the int expression `input ? 1 : 0` and the
 * actual hashing is delegated to [[genHashInt]].
 */
protected def genHashBoolean(input: String, result: String): String = {
  val asIntExpr = s"$input ? 1 : 0"
  genHashInt(asIntExpr, result)
}
371371

372-
// Deleted side of this diff (the "-" lines): the earlier genHashFloat handed
// Float.floatToIntBits(input) straight to genHashInt, with no special case for -0.0f.
protected def genHashFloat(input: String, result: String): String =
373-
genHashInt(s"Float.floatToIntBits($input)", result)
372+
/**
 * Builds the code snippet that hashes a float-typed `input` into `result`.
 *
 * The emitted snippet compares the bit pattern of `input` with that of `-0.0f`. On a match
 * it hashes the bits of `0.0f` instead, so negative and positive zero feed the same value
 * to the hash; every other input is hashed from its own `Float.floatToIntBits` pattern.
 *
 * @param input  code expression yielding the float to hash
 * @param result passed through unchanged to [[genHashInt]]
 * @return the generated if/else snippet as one string
 */
protected def genHashFloat(input: String, result: String): String = {
  // Branch bodies are generated in the same order as they appear in the output.
  val hashPositiveZero = genHashInt("Float.floatToIntBits(0.0f)", result)
  val hashRawBits = genHashInt(s"Float.floatToIntBits($input)", result)
  Seq(
    s"if(Float.floatToIntBits($input) == Float.floatToIntBits(-0.0f)) {",
    hashPositiveZero,
    "}else{",
    hashRawBits,
    "}"
  ).mkString
}
374379

375-
// Deleted side of this diff (the "-" lines): the earlier genHashDouble handed
// Double.doubleToLongBits(input) straight to genHashLong, with no special case for -0.0d.
protected def genHashDouble(input: String, result: String): String =
376-
genHashLong(s"Double.doubleToLongBits($input)", result)
380+
/**
 * Builds the code snippet that hashes a double-typed `input` into `result`.
 *
 * The emitted snippet compares the bit pattern of `input` with that of `-0.0d`. On a match
 * it hashes the bits of `0.0d` instead, so negative and positive zero feed the same value
 * to the hash; every other input is hashed from its own `Double.doubleToLongBits` pattern.
 *
 * @param input  code expression yielding the double to hash
 * @param result passed through unchanged to [[genHashLong]]
 * @return the generated if/else snippet as one string
 */
protected def genHashDouble(input: String, result: String): String = {
  // Branch bodies are generated in the same order as they appear in the output.
  val hashPositiveZero = genHashLong("Double.doubleToLongBits(0.0d)", result)
  val hashRawBits = genHashLong(s"Double.doubleToLongBits($input)", result)
  Seq(
    s"if(Double.doubleToLongBits($input) == Double.doubleToLongBits(-0.0d)) {",
    hashPositiveZero,
    "}else{",
    hashRawBits,
    "}"
  ).mkString
}
377387

378388
protected def genHashDecimal(
379389
ctx: CodegenContext,
@@ -523,7 +533,9 @@ abstract class InterpretedHashFunction {
523533
case s: Short => hashInt(s, seed)
524534
case i: Int => hashInt(i, seed)
525535
case l: Long => hashLong(l, seed)
536+
case f: Float if (f == -0.0f) => hashInt(java.lang.Float.floatToIntBits(0.0f), seed)
526537
case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
538+
case d: Double if (d == -0.0d) => hashLong(java.lang.Double.doubleToLongBits(0.0d), seed)
527539
case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
528540
case d: Decimal =>
529541
val precision = dataType.asInstanceOf[DecimalType].precision

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,16 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
708708
checkEvaluation(HiveHash(Seq(yearMonth)), 1234)
709709
}
710710

711+
test("SPARK-35207: Compute hash consistent between -0.0 and 0.0") {
  // Asserts that two expressions produce the same Murmur3, XxHash64 and Hive hash values.
  def checkResult(exprs1: Expression, exprs2: Expression): Unit = {
    assert(Murmur3Hash(Seq(exprs1), 42).eval() == Murmur3Hash(Seq(exprs2), 42).eval())
    assert(XxHash64(Seq(exprs1), 42).eval() == XxHash64(Seq(exprs2), 42).eval())
    assert(HiveHash(Seq(exprs1)).eval() == HiveHash(Seq(exprs2)).eval())
  }
  checkResult(Literal.create(0.0d, DoubleType), Literal.create(-0.0d, DoubleType))
  // Signed zero only exists for floating-point types. A Long comparison of 0L and -0L is
  // vacuous (both are the same integer), so the second case covers the Float path instead.
  checkResult(Literal.create(0.0f, FloatType), Literal.create(-0.0f, FloatType))
}
720+
711721
private def testHash(inputSchema: StructType): Unit = {
712722
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
713723
val toRow = RowEncoder(inputSchema).createSerializer()

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,4 +654,30 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
654654
}
655655
}
656656
}
657+
658+
test("SPARK-35207: Compute hash consistent between -0.0 and 0.0 doubles with Codegen") {
  // col1 / col2 is 0.0 / -1.0 (negative zero) and col1 / col3 is 0.0 / 1.0 (positive zero);
  // the query checks that both quotients hash to the same value.
  val data = Seq((0.0d, -1.0d, 1.0d))
  withTempPath { dir =>
    val path = dir.getCanonicalPath
    data.toDF("col1", "col2", "col3").write.parquet(path)
    try {
      sql(s"create table testHash(col1 double, col2 double, col3 double) " +
        s"using parquet location '$path'")
      val rows = sql("select hash(col1 / col2) == hash(col1 / col3) from testHash").collect()
      // Guard against a vacuous pass: with zero rows the per-row assertion never runs.
      assert(rows.nonEmpty)
      rows.foreach(row => assert(row.getBoolean(0)))
    } finally {
      // Always drop the table, even when an assertion fails, so it cannot leak into other
      // tests. "if exists" keeps a failed create from masking the original error.
      sql("drop table if exists testHash")
    }
  }
}
670+
671+
test("SPARK-35207: Compute hash consistent between -0.0 and 0.0 floats with Codegen") {
  // col1 / col2 is 0.0 / -1.0 (negative zero) and col1 / col3 is 0.0 / 1.0 (positive zero);
  // the query checks that both quotients hash to the same value.
  val data = Seq((0.0f, -1.0f, 1.0f))
  withTempPath { dir =>
    val path = dir.getCanonicalPath
    data.toDF("col1", "col2", "col3").write.parquet(path)
    try {
      sql(s"create table testHash(col1 float, col2 float, col3 float) " +
        s"using parquet location '$path'")
      // NOTE(review): Spark SQL's `/` appears to widen float operands to double (confirm
      // against the type-coercion rules). The explicit cast guarantees hash() receives a
      // float either way, so this test exercises the float hash path, not the double one.
      val rows = sql(
        "select hash(cast(col1 / col2 as float)) == hash(cast(col1 / col3 as float)) " +
          "from testHash").collect()
      // Guard against a vacuous pass: with zero rows the per-row assertion never runs.
      assert(rows.nonEmpty)
      rows.foreach(row => assert(row.getBoolean(0)))
    } finally {
      // Always drop the table, even when an assertion fails, so it cannot leak into other
      // tests. "if exists" keeps a failed create from masking the original error.
      sql("drop table if exists testHash")
    }
  }
}
657683
}

0 commit comments

Comments
 (0)