Skip to content

[SPARK-26021][SQL][followup] only deal with NaN and -0.0 in UnsafeWriter #23239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,6 @@ public static float getFloat(Object object, long offset) {
}

public static void putFloat(Object object, long offset, float value) {
if (Float.isNaN(value)) {
value = Float.NaN;
} else if (value == -0.0f) {
value = 0.0f;
}
Copy link
Member

@dongjoon-hyun dongjoon-hyun Dec 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are expected to cause the following test case failure in PlatformUtilSuite, but it seems to have been missed. Could you fix the test case or remove it as well, @cloud-fan?

_UNSAFE.putFloat(object, offset, value);
}

Expand All @@ -187,11 +182,6 @@ public static double getDouble(Object object, long offset) {
}

public static void putDouble(Object object, long offset, double value) {
if (Double.isNaN(value)) {
value = Double.NaN;
} else if (value == -0.0d) {
value = 0.0d;
}
_UNSAFE.putDouble(object, offset, value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,4 @@ public void heapMemoryReuse() {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}

@Test
// SPARK-26021
public void writeMinusZeroIsReplacedWithZero() {
byte[] doubleBytes = new byte[Double.BYTES];
byte[] floatBytes = new byte[Float.BYTES];
Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);

byte[] doubleBytes2 = new byte[Double.BYTES];
byte[] floatBytes2 = new byte[Float.BYTES];
Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d);
Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f);

// Make sure the bytes we write from 0.0 and -0.0 are the same.
Assert.assertArrayEquals(doubleBytes, doubleBytes2);
Assert.assertArrayEquals(floatBytes, floatBytes2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,46 @@ protected final void writeLong(long offset, long value) {
Platform.putLong(getBuffer(), offset, value);
}

// NaN and -0.0 need special care in several places:
// 1. Value comparison: all NaNs must compare as equal, and `-0.0` must equal `0.0`.
// 2. GROUP BY: all NaNs must fall into one group, and -0.0 and 0.0 into one group.
// 3. Join keys: all NaNs must match each other, and `-0.0` must match `0.0`.
// 4. Window partition keys: same requirement as join keys.
//
// Case 1 already works, because the comparison logic handles NaN and -0.0
// correctly, and complex types are compared recursively field-by-field /
// element-by-element.
//
// Cases 2, 3 and 4 are problematic, because they compare the raw binary of
// `UnsafeRow`s, and different NaNs (as well as -0.0 vs 0.0) have different bit
// patterns.
//
// The fix: normalize NaN and -0.0 here, so that every `UnsafeProjection` writes
// normalized float/double values (including nested ones) into `UnsafeRow`.
//
// Note that this must happen in ALL `UnsafeProjection`s, not only the ones that
// extract join/grouping/window partition keys: `UnsafeProjection` copies unsafe
// data of complex types directly, so nested float/double values would otherwise
// escape normalization. All unsafe data (`UnsafeRow`, `UnsafeArrayData`,
// `UnsafeMapData`) must have float/double normalized at creation time.
protected final void writeFloat(long offset, float value) {
  float normalized = value;
  // `-0.0f == 0.0f` is true under IEEE 754, so this branch is a no-op for
  // positive zero; NaN compares false to everything, so the order of the two
  // checks does not matter.
  if (normalized == -0.0f) {
    normalized = 0.0f;
  } else if (Float.isNaN(normalized)) {
    // Collapse every NaN bit pattern to the single canonical `Float.NaN`.
    normalized = Float.NaN;
  }
  Platform.putFloat(getBuffer(), offset, normalized);
}

// See the comments on `writeFloat`: normalizes NaN and -0.0 so that binary
// comparison of unsafe data treats all NaNs (and both signed zeros) as equal.
protected final void writeDouble(long offset, double value) {
  double normalized = value;
  if (normalized == -0.0d) {
    normalized = 0.0d;
  } else if (Double.isNaN(normalized)) {
    // Collapse every NaN bit pattern to the single canonical `Double.NaN`.
    normalized = Double.NaN;
  }
  Platform.putDouble(getBuffer(), offset, normalized);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
assert(res1 == res2)
}

test("SPARK-26021: normalize float/double NaN and -0.0") {
  // Builds a 4-column row (float, double, float, double) via UnsafeRowWriter.
  def buildRow(f0: Float, d1: Double, f2: Float, d3: Double) = {
    val writer = new UnsafeRowWriter(4)
    writer.resetRowWriter()
    writer.write(0, f0)
    writer.write(1, d1)
    writer.write(2, f2)
    writer.write(3, d3)
    writer.getRow
  }

  // Canonical NaN / positive zero vs. computed NaN / negative zero:
  // after normalization both rows must be equal.
  val canonical = buildRow(Float.NaN, Double.NaN, 0.0f, 0.0)
  val variants = buildRow(0.0f / 0.0f, 0.0 / 0.0, -0.0f, -0.0)
  assert(canonical == variants)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
}
}

test("NaN and -0.0 in join keys") {
  // Rows whose keys differ only in NaN bit pattern or zero sign must still
  // join with each other: NaN matches NaN, and -0.0 matches 0.0.
  val data = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0)
  val left = data.toDF("f", "d")
  val right = data.toDF("f", "d")
  val joined = left.join(right, Seq("f", "d"))
  // One NaN/NaN match plus the 2x2 cross of the zero rows (all normalized
  // to positive zero in the output).
  val expected = Row(Float.NaN, Double.NaN) +: Seq.fill(4)(Row(0.0f, 0.0))
  checkAnswer(joined, expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -681,4 +681,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row("S2", "P2", 300, 300, 500)))

}

test("NaN and -0.0 in window partition keys") {
  // Rows whose partition keys differ only in NaN bit pattern or zero sign
  // must land in the same window partition, so each partition counts 2 rows.
  val input = Seq(
    (Float.NaN, Double.NaN, 1),
    (0.0f / 0.0f, 0.0 / 0.0, 1),
    (0.0f, 0.0, 1),
    (-0.0f, -0.0, 1)
  ).toDF("f", "d", "i")
  val windowed = input.select($"f", count("i").over(Window.partitionBy("f", "d")))
  val expected =
    Row(Float.NaN, 2) :: Row(Float.NaN, 2) :: Row(0.0f, 2) :: Row(0.0f, 2) :: Nil
  checkAnswer(windowed, expected)
}
}