
Commit e853afb

cloud-fan authored and gatorsmile committed
[SPARK-26448][SQL] retain the difference between 0.0 and -0.0
## What changes were proposed in this pull request?

In #23043, we introduced a behavior change: Spark users are no longer able to distinguish 0.0 and -0.0. This PR proposes an alternative fix to the original bug that retains the difference between 0.0 and -0.0 inside Spark. The idea is to rewrite window partition keys, join keys and grouping keys during the logical phase to normalize the special floating-point numbers. Thus only the operators that care about special floating-point numbers pay the performance overhead, and end users can still distinguish -0.0.

## How was this patch tested?

Existing tests.

Closes #23388 from cloud-fan/minor.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
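As a quick illustration of the intended behavior, here is a minimal sketch assuming a Spark 3.0 `spark-shell` (a `SparkSession` named `spark` with its implicits in scope); the `Row` rendering shown in the comments is indicative only:

```scala
import spark.implicits._

// -0.0 is kept as-is in the data, so users can still distinguish it:
Seq(-0.0, 0.0).toDF("d").collect()   // expected: Array([-0.0], [0.0])

// ...but join keys are normalized, so a -0.0 key matches a 0.0 key:
val left = Seq(-0.0).toDF("k")
val right = Seq(0.0).toDF("k")
left.join(right, "k").count()        // expected: 1
```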
1 parent 49c062b · commit e853afb

15 files changed (+436, -153 lines)

docs/sql-migration-guide-upgrade.md

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ displayTitle: Spark SQL Upgrading Guide
 
   - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`.
 
-  - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more.
+  - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier.
 
   - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be undefined.
 
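The migration-guide example above, written out as a hedged spark-shell snippet (assumes `import spark.implicits._` is available; the exact `Row` rendering may differ):

```scala
import spark.implicits._

Seq(-0.0, 0.0).toDF("d").groupBy("d").count().collect()
// Spark 3.0:              Array([0.0,2])
// Spark 2.4 and earlier:  Array([0.0,1], [-0.0,1])
```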

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java

Lines changed: 0 additions & 35 deletions
@@ -198,46 +198,11 @@ protected final void writeLong(long offset, long value) {
     Platform.putLong(getBuffer(), offset, value);
   }
 
-  // We need to take care of NaN and -0.0 in several places:
-  // 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be
-  //    treated as same.
-  // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong
-  //    to the same group.
-  // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be
-  //    treated as same.
-  // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0`
-  //    should be treated as same.
-  //
-  // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
-  // recursively compare the fields/elements, so it's also fine.
-  //
-  // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
-  // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0.
-  //
-  // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
-  // float/double columns and nested fields to `UnsafeRow`.
-  //
-  // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract
-  // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex
-  // types, so nested float/double may not be normalized. We need to make sure that all the unsafe
-  // data(`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have flat/double normalized during
-  // creation.
   protected final void writeFloat(long offset, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    } else if (value == -0.0f) {
-      value = 0.0f;
-    }
     Platform.putFloat(getBuffer(), offset, value);
   }
 
-  // See comments for `writeFloat`.
   protected final void writeDouble(long offset, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    } else if (value == -0.0d) {
-      value = 0.0d;
-    }
     Platform.putDouble(getBuffer(), offset, value);
   }
 }
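The "different binary representation" problem described in the removed comment can be reproduced without Spark, using only the JDK bit-conversion helpers; the NaN bit patterns below are the same ones used by the `UnsafeRowConverterSuite` test that this commit removes:

```scala
// 0.0 and -0.0 compare equal as doubles, but their bit patterns differ, so a
// byte-wise comparison of UnsafeRow data would treat them as different keys.
assert(0.0 == -0.0)
assert(java.lang.Double.doubleToRawLongBits(0.0) !=
  java.lang.Double.doubleToRawLongBits(-0.0))

// Likewise, many bit patterns encode NaN, and each has a distinct binary form.
val nan1 = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2 = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
assert(nan1.isNaN && nan2.isNaN)
assert(java.lang.Double.doubleToRawLongBits(nan1) !=
  java.lang.Double.doubleToRawLongBits(nan2))
```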
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala (new file)

Lines changed: 198 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._

/**
 * We need to take care of special floating numbers (NaN and -0.0) in several places:
 * 1. When comparing values, different NaNs should be treated as the same, and `-0.0` and `0.0`
 *    should be treated as the same.
 * 2. In aggregate grouping keys, different NaNs should belong to the same group, and -0.0 and
 *    0.0 should belong to the same group.
 * 3. In join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should be
 *    treated as the same.
 * 4. In window partition keys, different NaNs should belong to the same partition, and -0.0 and
 *    0.0 should belong to the same partition.
 *
 * Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
 * recursively compare the fields/elements, so it's also fine.
 *
 * Cases 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window partition keys into
 * binary `UnsafeRow`s and compares the binary data directly. Different NaNs have different binary
 * representations, and the same is true for -0.0 and 0.0.
 *
 * This rule normalizes NaN and -0.0 in window partition keys, join keys and aggregate grouping
 * keys.
 *
 * Ideally we should do the normalization in the physical operators that compare the binary
 * `UnsafeRow` directly. We wouldn't need this normalization if the Spark SQL execution engine were
 * not optimized to run on binary data. This rule is created to simplify the implementation, so
 * that we have a single place to do normalization, which is more maintainable.
 *
 * Note that this rule must be executed at the end of the optimizer, because the optimizer may
 * create new joins (the subquery rewrite) and new join conditions (the join reorder).
 */
object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

  def apply(plan: LogicalPlan): LogicalPlan = plan match {
    // A subquery will be rewritten into a join later, and will go through this rule
    // eventually. Here we skip subqueries, as we only need to run this rule once.
    case _: Subquery => plan

    case _ => plan transform {
      case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
        // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
        // to normalize the `windowExpressions`, as they are executed per input row and should take
        // the input row as it is.
        w.copy(partitionSpec = w.partitionSpec.map(normalize))

      // Only hash join and sort merge join need the normalization. Here we catch all Joins with
      // join keys, assuming Joins with join keys are always planned as hash join or sort merge
      // join. It's very unlikely that we will break this assumption in the near future.
      case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
          // The analyzer guarantees left and right join keys are of the same data type. Here we
          // only need to check join keys of one side.
          if leftKeys.exists(k => needNormalize(k.dataType)) =>
        val newLeftJoinKeys = leftKeys.map(normalize)
        val newRightJoinKeys = rightKeys.map(normalize)
        val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
          case (l, r) => EqualTo(l, r)
        } ++ condition
        j.copy(condition = Some(newConditions.reduce(And)))

      // TODO: ideally Aggregate should also be handled here, but its grouping expressions are
      // mixed in its aggregate expressions. It's unreliable to change the grouping expressions
      // here. For now we normalize grouping expressions in `AggUtils` during planning.
    }
  }

  private def needNormalize(dt: DataType): Boolean = dt match {
    case FloatType | DoubleType => true
    case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
    case ArrayType(et, _) => needNormalize(et)
    // Currently MapType is not comparable, and the analyzer should fail earlier if this case
    // happens.
    case _: MapType =>
      throw new IllegalStateException("grouping/join/window partition keys cannot be map type.")
    case _ => false
  }

  private[sql] def normalize(expr: Expression): Expression = expr match {
    case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
      NormalizeNaNAndZero(expr)

    case CreateNamedStruct(children) =>
      CreateNamedStruct(children.map(normalize))

    case CreateNamedStructUnsafe(children) =>
      CreateNamedStructUnsafe(children.map(normalize))

    case CreateArray(children) =>
      CreateArray(children.map(normalize))

    case CreateMap(children) =>
      CreateMap(children.map(normalize))

    case a: Alias if needNormalize(a.dataType) =>
      a.withNewChildren(Seq(normalize(a.child)))

    case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) =>
      val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
        normalize(GetStructField(expr, i))
      }
      CreateStruct(fields)

    case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) =>
      val ArrayType(et, containsNull) = expr.dataType
      val lv = NamedLambdaVariable("arg", et, containsNull)
      val function = normalize(lv)
      ArrayTransform(expr, LambdaFunction(function, Seq(lv)))

    case _ => expr
  }
}

case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes {

  override def dataType: DataType = child.dataType

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType))

  private lazy val normalizer: Any => Any = child.dataType match {
    case FloatType => (input: Any) => {
      val f = input.asInstanceOf[Float]
      if (f.isNaN) {
        Float.NaN
      } else if (f == -0.0f) {
        0.0f
      } else {
        f
      }
    }

    case DoubleType => (input: Any) => {
      val d = input.asInstanceOf[Double]
      if (d.isNaN) {
        Double.NaN
      } else if (d == -0.0d) {
        0.0d
      } else {
        d
      }
    }
  }

  override def nullSafeEval(input: Any): Any = {
    normalizer(input)
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val codeToNormalize = child.dataType match {
      case FloatType => (f: String) => {
        s"""
           |if (Float.isNaN($f)) {
           |  ${ev.value} = Float.NaN;
           |} else if ($f == -0.0f) {
           |  ${ev.value} = 0.0f;
           |} else {
           |  ${ev.value} = $f;
           |}
         """.stripMargin
      }

      case DoubleType => (d: String) => {
        s"""
           |if (Double.isNaN($d)) {
           |  ${ev.value} = Double.NaN;
           |} else if ($d == -0.0d) {
           |  ${ev.value} = 0.0d;
           |} else {
           |  ${ev.value} = $d;
           |}
         """.stripMargin
      }
    }

    nullSafeCodeGen(ctx, ev, codeToNormalize)
  }
}
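At runtime, `NormalizeNaNAndZero` maps every NaN bit pattern to the canonical NaN and -0.0 to 0.0, so that the bytes written into `UnsafeRow` agree. A standalone sketch of the double case (plain Scala, mirroring the interpreted `normalizer` above; the real expression also covers `FloatType` and has a codegen path):

```scala
// Mirrors the DoubleType branch of NormalizeNaNAndZero's interpreted evaluation.
def normalizeDouble(d: Double): Double =
  if (d.isNaN) Double.NaN
  else if (d == -0.0d) 0.0d
  else d

// After normalization the raw bits (and hence the UnsafeRow bytes) agree:
val oddNaN = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
assert(java.lang.Double.doubleToRawLongBits(normalizeDouble(oddNaN)) ==
  java.lang.Double.doubleToRawLongBits(Double.NaN))
assert(java.lang.Double.doubleToRawLongBits(normalizeDouble(-0.0)) ==
  java.lang.Double.doubleToRawLongBits(0.0))
```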

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 2 deletions
@@ -180,7 +180,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       CollapseProject,
       RemoveNoopOperators) :+
     Batch("UpdateAttributeReferences", Once,
-      UpdateNullabilityInAttributeReferences)
+      UpdateNullabilityInAttributeReferences) :+
+    // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
+    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers)
   }
 
   /**
@@ -210,7 +212,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     PullupCorrelatedPredicates.ruleName ::
     RewriteCorrelatedScalarSubquery.ruleName ::
     RewritePredicateSubquery.ruleName ::
-    PullOutPythonUDFInJoinCondition.ruleName :: Nil
+    PullOutPythonUDFInJoinCondition.ruleName ::
+    NormalizeFloatingNumbers.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 2 additions & 3 deletions
@@ -105,8 +105,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
     (JoinType, Seq[Expression], Seq[Expression],
       Option[Expression], LogicalPlan, LogicalPlan, JoinHint)
 
-  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
-    case join @ Join(left, right, joinType, condition, hint) =>
+  def unapply(join: Join): Option[ReturnType] = join match {
+    case Join(left, right, joinType, condition, hint) =>
       logDebug(s"Considering join on: $condition")
       // Find equi-join predicates that can be evaluated before the join, and thus can be used
       // as join keys.
@@ -140,7 +140,6 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       } else {
        None
       }
-    case _ => None
   }
 }
 

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 0 additions & 16 deletions
@@ -246,22 +246,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
     // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
-  testBothCodegenAndInterpreted("NaN canonicalization") {
-    val factory = UnsafeProjection
-    val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
-
-    val row1 = new SpecificInternalRow(fieldTypes)
-    row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001))
-    row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L))
-
-    val row2 = new SpecificInternalRow(fieldTypes)
-    row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
-    row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
-
-    val converter = factory.create(fieldTypes)
-    assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
-  }
-
   testBothCodegenAndInterpreted("basic conversion with struct type") {
     val factory = UnsafeProjection
     val fieldTypes: Array[DataType] = Array(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala

Lines changed: 0 additions & 21 deletions
@@ -49,25 +49,4 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
     // The two rows should be the equal
     assert(res1 == res2)
   }
-
-  test("SPARK-26021: normalize float/double NaN and -0.0") {
-    val unsafeRowWriter1 = new UnsafeRowWriter(4)
-    unsafeRowWriter1.resetRowWriter()
-    unsafeRowWriter1.write(0, Float.NaN)
-    unsafeRowWriter1.write(1, Double.NaN)
-    unsafeRowWriter1.write(2, 0.0f)
-    unsafeRowWriter1.write(3, 0.0)
-    val res1 = unsafeRowWriter1.getRow
-
-    val unsafeRowWriter2 = new UnsafeRowWriter(4)
-    unsafeRowWriter2.resetRowWriter()
-    unsafeRowWriter2.write(0, 0.0f/0.0f)
-    unsafeRowWriter2.write(1, 0.0/0.0)
-    unsafeRowWriter2.write(2, -0.0f)
-    unsafeRowWriter2.write(3, -0.0)
-    val res2 = unsafeRowWriter2.getRow
-
-    // The two rows should be the equal
-    assert(res1 == res2)
-  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 12 additions & 4 deletions
@@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -35,12 +35,20 @@ object AggUtils {
       initialInputBufferOffset: Int = 0,
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
+    // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
+    // `groupingExpressions` is not extracted during logical phase.
+    val normalizedGroupingExpressions = groupingExpressions.map { e =>
+      NormalizeFloatingNumbers.normalize(e) match {
+        case n: NamedExpression => n
+        case other => Alias(other, e.name)(exprId = e.exprId)
+      }
+    }
     val useHash = HashAggregateExec.supportsAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
     if (useHash) {
       HashAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = groupingExpressions,
+        groupingExpressions = normalizedGroupingExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = initialInputBufferOffset,
@@ -53,7 +61,7 @@
       if (objectHashEnabled && useObjectHash) {
         ObjectHashAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-          groupingExpressions = groupingExpressions,
+          groupingExpressions = normalizedGroupingExpressions,
           aggregateExpressions = aggregateExpressions,
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
@@ -62,7 +70,7 @@
       } else {
         SortAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-          groupingExpressions = groupingExpressions,
+          groupingExpressions = normalizedGroupingExpressions,
           aggregateExpressions = aggregateExpressions,
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
