Commit 9ea55fe

planga82 authored and cloud-fan committed
[SPARK-35207][SQL] Normalize hash function behavior with negative zero (floating point types)
### What changes were proposed in this pull request?

Generally, we would expect that `x = y` implies `hash(x) = hash(y)`. However, +0 and -0 hash to different values for floating point types.

```
scala> spark.sql("select hash(cast('0.0' as double)), hash(cast('-0.0' as double))").show
+-------------------------+--------------------------+
|hash(CAST(0.0 AS DOUBLE))|hash(CAST(-0.0 AS DOUBLE))|
+-------------------------+--------------------------+
|              -1670924195|               -853646085|
+-------------------------+--------------------------+

scala> spark.sql("select cast('0.0' as double) == cast('-0.0' as double)").show
+--------------------------------------------+
|(CAST(0.0 AS DOUBLE) = CAST(-0.0 AS DOUBLE))|
+--------------------------------------------+
|                                        true|
+--------------------------------------------+
```

Here is an extract from IEEE 754:

> The two zeros are distinguishable arithmetically only by either division-by-zero (producing appropriately signed infinities) or else by the CopySign function recommended by IEEE 754/854. Infinities, SNaNs, NaNs and Subnormal numbers necessitate four more special cases.

From this, I deduce that the hash function must produce the same result for 0 and -0.

### Why are the changes needed?

It is a correctness issue.

### Does this PR introduce _any_ user-facing change?

This change only affects the hash function applied to the -0 value in float and double types.

### How was this patch tested?

Unit testing and manual testing.

Closes #32496 from planga82/feature/spark35207_hashnegativezero.

Authored-by: Pablo Langa <soypab@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
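For context, here is a minimal plain-Scala sketch (not part of the patch) illustrating the root cause: `0.0` and `-0.0` compare equal, but their IEEE 754 bit patterns differ, so hashing the raw bits yields different results unless the negative zero is normalized first.

```scala
// Plain Scala illustration of why hashing raw IEEE 754 bits distinguishes the two zeros.
object NegativeZeroBits {
  def main(args: Array[String]): Unit = {
    println(0.0d == -0.0d)                             // true: the two zeros compare equal
    println(java.lang.Double.doubleToLongBits(0.0d))   // 0
    println(java.lang.Double.doubleToLongBits(-0.0d))  // -9223372036854775808 (only the sign bit is set)
    println(java.lang.Float.floatToIntBits(0.0f))      // 0
    println(java.lang.Float.floatToIntBits(-0.0f))     // -2147483648 (only the sign bit is set)
  }
}
```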
1 parent f7af9ab · commit 9ea55fe

File tree

3 files changed (+33, -4 lines)


docs/sql-migration-guide.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -87,6 +87,8 @@ license: |
 
 - In Spark 3.2, Spark supports `DayTimeIntervalType` and `YearMonthIntervalType` as inputs and outputs of `TRANSFORM` clause in Hive `SERDE` mode, the behavior is different between Hive `SERDE` mode and `ROW FORMAT DELIMITED` mode when these two types are used as inputs. In Hive `SERDE` mode, `DayTimeIntervalType` column is converted to `HiveIntervalDayTime`, its string format is `[-]?d h:m:s.n`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?d h:m:s.n' DAY TO TIME`. In Hive `SERDE` mode, `YearMonthIntervalType` column is converted to `HiveIntervalYearMonth`, its string format is `[-]?y-m`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?y-m' YEAR TO MONTH`.
 
+- In Spark 3.2, `hash(0) == hash(-0)` for floating point types. Previously, different values were generated.
+
 ## Upgrading from Spark SQL 3.0 to 3.1
 
 - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.
```
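As a quick illustration of the documented behavior, the following snippet can be run in a Spark 3.2 `spark-shell` (an illustrative check, not part of the patch; it assumes an active `spark` session):

```scala
// With this change, both zeros hash to the same value, so the comparison returns true.
spark.sql(
  "SELECT hash(CAST('0.0' AS DOUBLE)) = hash(CAST('-0.0' AS DOUBLE)) AS same_hash"
).show()
// +---------+
// |same_hash|
// +---------+
// |     true|
// +---------+
```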

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 20 additions & 4 deletions
```diff
@@ -369,11 +369,25 @@ abstract class HashExpression[E] extends Expression {
   protected def genHashBoolean(input: String, result: String): String =
     genHashInt(s"$input ? 1 : 0", result)
 
-  protected def genHashFloat(input: String, result: String): String =
-    genHashInt(s"Float.floatToIntBits($input)", result)
+  protected def genHashFloat(input: String, result: String): String = {
+    s"""
+       |if($input == -0.0f) {
+       |  ${genHashInt("0", result)}
+       |} else {
+       |  ${genHashInt(s"Float.floatToIntBits($input)", result)}
+       |}
+     """.stripMargin
+  }
 
-  protected def genHashDouble(input: String, result: String): String =
-    genHashLong(s"Double.doubleToLongBits($input)", result)
+  protected def genHashDouble(input: String, result: String): String = {
+    s"""
+       |if($input == -0.0d) {
+       |  ${genHashLong("0L", result)}
+       |} else {
+       |  ${genHashLong(s"Double.doubleToLongBits($input)", result)}
+       |}
+     """.stripMargin
+  }
 
   protected def genHashDecimal(
       ctx: CodegenContext,
@@ -523,7 +537,9 @@ abstract class InterpretedHashFunction {
       case s: Short => hashInt(s, seed)
       case i: Int => hashInt(i, seed)
      case l: Long => hashLong(l, seed)
+      case f: Float if (f == -0.0f) => hashInt(0, seed)
       case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
+      case d: Double if (d == -0.0d) => hashLong(0L, seed)
       case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
       case d: Decimal =>
         val precision = dataType.asInstanceOf[DecimalType].precision
```
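The normalization applied in both the codegen and interpreted paths can be summarized by the following standalone sketch (hypothetical helper names, not Spark code): negative zero is mapped to positive zero's bit pattern before the bits are fed to the hash function.

```scala
// Standalone sketch of the normalization idea used by the patch.
// The guard also matches +0.0 (the two zeros compare equal), whose bit pattern
// is already all zeros, so the result is unchanged in that case.
def normalizedFloatBits(f: Float): Int =
  if (f == -0.0f) 0 else java.lang.Float.floatToIntBits(f)

def normalizedDoubleBits(d: Double): Long =
  if (d == -0.0d) 0L else java.lang.Double.doubleToLongBits(d)
```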

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -708,6 +708,17 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(HiveHash(Seq(yearMonth)), 1234)
   }
 
+  test("SPARK-35207: Compute hash consistent between -0.0 and 0.0") {
+    def checkResult(exprs1: Expression, exprs2: Expression): Unit = {
+      checkEvaluation(Murmur3Hash(Seq(exprs1), 42), Murmur3Hash(Seq(exprs2), 42).eval())
+      checkEvaluation(XxHash64(Seq(exprs1), 42), XxHash64(Seq(exprs2), 42).eval())
+      checkEvaluation(HiveHash(Seq(exprs1)), HiveHash(Seq(exprs2)).eval())
+    }
+
+    checkResult(Literal.create(-0D, DoubleType), Literal.create(0D, DoubleType))
+    checkResult(Literal.create(-0F, FloatType), Literal.create(0F, FloatType))
+  }
+
   private def testHash(inputSchema: StructType): Unit = {
     val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
     val toRow = RowEncoder(inputSchema).createSerializer()
```
