Skip to content

Commit 33460c5

Browse files
cloud-fan authored and dongjoon-hyun committed
[SPARK-26021][2.4][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter
backport #23239 to 2.4 --------- ## What changes were proposed in this pull request? A followup of #23043 There are 4 places we need to deal with NaN and -0.0: 1. comparison expressions. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. 2. Join keys. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. 3. grouping keys. `-0.0` and `0.0` should be assigned to the same group. Different NaNs should be assigned to the same group. 4. window partition keys. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. The case 1 is OK. Our comparison already handles NaN and -0.0, and for struct/array/map, we will recursively compare the fields/elements. Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. To fix it, a simple solution is: normalize float/double when building unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't need to worry about it anymore. Following this direction, this PR moves the handling of NaN and -0.0 from `Platform` to `UnsafeWriter`, so that places like `UnsafeRow.setFloat` will not handle them, which reduces the perf overhead. It's also easier to add comments explaining why we do it in `UnsafeWriter`. ## How was this patch tested? existing tests Closes #23265 from cloud-fan/minor. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent a073b1c commit 33460c5

File tree

6 files changed

+81
-24
lines changed

6 files changed

+81
-24
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,6 @@ public static float getFloat(Object object, long offset) {
120120
}
121121

122122
public static void putFloat(Object object, long offset, float value) {
123-
if (Float.isNaN(value)) {
124-
value = Float.NaN;
125-
} else if (value == -0.0f) {
126-
value = 0.0f;
127-
}
128123
_UNSAFE.putFloat(object, offset, value);
129124
}
130125

@@ -133,11 +128,6 @@ public static double getDouble(Object object, long offset) {
133128
}
134129

135130
public static void putDouble(Object object, long offset, double value) {
136-
if (Double.isNaN(value)) {
137-
value = Double.NaN;
138-
} else if (value == -0.0d) {
139-
value = 0.0d;
140-
}
141131
_UNSAFE.putDouble(object, offset, value);
142132
}
143133

common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -157,18 +157,4 @@ public void heapMemoryReuse() {
157157
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
158158
Assert.assertEquals(obj3, onheap4.getBaseObject());
159159
}
160-
161-
@Test
162-
// SPARK-26021
163-
public void writeMinusZeroIsReplacedWithZero() {
164-
byte[] doubleBytes = new byte[Double.BYTES];
165-
byte[] floatBytes = new byte[Float.BYTES];
166-
Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
167-
Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
168-
double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET);
169-
float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET);
170-
171-
Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform));
172-
Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform));
173-
}
174160
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,46 @@ protected final void writeLong(long offset, long value) {
198198
Platform.putLong(getBuffer(), offset, value);
199199
}
200200

201+
// We need to take care of NaN and -0.0 in several places:
202+
// 1. When comparing values, different NaNs should be treated as same, `-0.0` and `0.0` should be
203+
// treated as same.
204+
// 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong
205+
// to the same group.
206+
// 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be
207+
// treated as same.
208+
// 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0`
209+
// should be treated as same.
210+
//
211+
// Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
212+
// recursively compare the fields/elements, so it's also fine.
213+
//
214+
// Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
215+
// NaNs have different binary representation, and the same thing happens for -0.0 and 0.0.
216+
//
217+
// Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
218+
// float/double columns and nested fields to `UnsafeRow`.
219+
//
220+
// Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract
221+
// join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex
222+
// types, so nested float/double may not be normalized. We need to make sure that all the unsafe
223+
// data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have float/double normalized during
224+
// creation.
201225
protected final void writeFloat(long offset, float value) {
226+
if (Float.isNaN(value)) {
227+
value = Float.NaN;
228+
} else if (value == -0.0f) {
229+
value = 0.0f;
230+
}
202231
Platform.putFloat(getBuffer(), offset, value);
203232
}
204233

234+
// See comments for `writeFloat`.
205235
protected final void writeDouble(long offset, double value) {
236+
if (Double.isNaN(value)) {
237+
value = Double.NaN;
238+
} else if (value == -0.0d) {
239+
value = 0.0d;
240+
}
206241
Platform.putDouble(getBuffer(), offset, value);
207242
}
208243
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
5050
assert(res1 == res2)
5151
}
5252

53+
test("SPARK-26021: normalize float/double NaN and -0.0") {
54+
val unsafeRowWriter1 = new UnsafeRowWriter(4)
55+
unsafeRowWriter1.resetRowWriter()
56+
unsafeRowWriter1.write(0, Float.NaN)
57+
unsafeRowWriter1.write(1, Double.NaN)
58+
unsafeRowWriter1.write(2, 0.0f)
59+
unsafeRowWriter1.write(3, 0.0)
60+
val res1 = unsafeRowWriter1.getRow
61+
62+
val unsafeRowWriter2 = new UnsafeRowWriter(4)
63+
unsafeRowWriter2.resetRowWriter()
64+
unsafeRowWriter2.write(0, 0.0f/0.0f)
65+
unsafeRowWriter2.write(1, 0.0/0.0)
66+
unsafeRowWriter2.write(2, -0.0f)
67+
unsafeRowWriter2.write(3, -0.0)
68+
val res2 = unsafeRowWriter2.getRow
69+
70+
// The two rows should be equal
71+
assert(res1 == res2)
72+
}
5373
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
295295
df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
296296
}
297297
}
298+
299+
test("NaN and -0.0 in join keys") {
300+
val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
301+
val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
302+
val joined = df1.join(df2, Seq("f", "d"))
303+
checkAnswer(joined, Seq(
304+
Row(Float.NaN, Double.NaN),
305+
Row(0.0f, 0.0),
306+
Row(0.0f, 0.0),
307+
Row(0.0f, 0.0),
308+
Row(0.0f, 0.0)))
309+
}
298310
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,4 +658,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
658658
|GROUP BY a
659659
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
660660
}
661+
662+
test("NaN and -0.0 in window partition keys") {
663+
val df = Seq(
664+
(Float.NaN, Double.NaN, 1),
665+
(0.0f/0.0f, 0.0/0.0, 1),
666+
(0.0f, 0.0, 1),
667+
(-0.0f, -0.0, 1)).toDF("f", "d", "i")
668+
val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
669+
checkAnswer(result, Seq(
670+
Row(Float.NaN, 2),
671+
Row(Float.NaN, 2),
672+
Row(0.0f, 2),
673+
Row(0.0f, 2)))
674+
}
661675
}

0 commit comments

Comments
 (0)