Skip to content

Commit d63ab5a

Browse files
Alon Doroncloud-fan
Alon Doron
authored andcommitted
[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float
GROUP BY treats -0.0 and 0.0 as different values which is unlike hive's behavior. In addition current behavior with codegen is unpredictable (see example in JIRA ticket). ## What changes were proposed in this pull request? In Platform.putDouble/Float() checking if the value is -0.0, and if so replacing with 0.0. This is used by UnsafeRow so it won't have -0.0 values. ## How was this patch tested? Added tests Closes #23043 from adoron/adoron-spark-26021-replace-minus-zero-with-zero. Authored-by: Alon Doron <adoron@palantir.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com> (cherry picked from commit 0ec7b99) Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 8705a9d commit d63ab5a

File tree

6 files changed

+42
-13
lines changed

6 files changed

+42
-13
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ public static float getFloat(Object object, long offset) {
120120
}
121121

122122
public static void putFloat(Object object, long offset, float value) {
123+
if (Float.isNaN(value)) {
124+
value = Float.NaN;
125+
} else if (value == -0.0f) {
126+
value = 0.0f;
127+
}
123128
_UNSAFE.putFloat(object, offset, value);
124129
}
125130

@@ -128,6 +133,11 @@ public static double getDouble(Object object, long offset) {
128133
}
129134

130135
public static void putDouble(Object object, long offset, double value) {
136+
if (Double.isNaN(value)) {
137+
value = Double.NaN;
138+
} else if (value == -0.0d) {
139+
value = 0.0d;
140+
}
131141
_UNSAFE.putDouble(object, offset, value);
132142
}
133143

common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,18 @@ public void heapMemoryReuse() {
157157
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
158158
Assert.assertEquals(obj3, onheap4.getBaseObject());
159159
}
160+
161+
@Test
162+
// SPARK-26021
163+
public void writeMinusZeroIsReplacedWithZero() {
164+
byte[] doubleBytes = new byte[Double.BYTES];
165+
byte[] floatBytes = new byte[Float.BYTES];
166+
Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
167+
Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
168+
double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET);
169+
float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET);
170+
171+
Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform));
172+
Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform));
173+
}
160174
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) {
224224
public void setDouble(int ordinal, double value) {
225225
assertIndexIsValid(ordinal);
226226
setNotNullAt(ordinal);
227-
if (Double.isNaN(value)) {
228-
value = Double.NaN;
229-
}
230227
Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
231228
}
232229

@@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) {
255252
public void setFloat(int ordinal, float value) {
256253
assertIndexIsValid(ordinal);
257254
setNotNullAt(ordinal);
258-
if (Float.isNaN(value)) {
259-
value = Float.NaN;
260-
}
261255
Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
262256
}
263257

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) {
199199
}
200200

201201
protected final void writeFloat(long offset, float value) {
202-
if (Float.isNaN(value)) {
203-
value = Float.NaN;
204-
}
205202
Platform.putFloat(getBuffer(), offset, value);
206203
}
207204

208205
protected final void writeDouble(long offset, double value) {
209-
if (Double.isNaN(value)) {
210-
value = Double.NaN;
211-
}
212206
Platform.putDouble(getBuffer(), offset, value);
213207
}
214208
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,4 +727,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
727727
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
728728
"type: GroupBy]"))
729729
}
730+
731+
test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
732+
val colName = "i"
733+
val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect()
734+
val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect()
735+
736+
assert(doubles.length == 1)
737+
assert(floats.length == 1)
738+
// using compare since 0.0 == -0.0 is true
739+
assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0)
740+
assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0)
741+
assert(doubles(0).getLong(1) == 3)
742+
assert(floats(0).getLong(1) == 3)
743+
}
730744
}

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ object QueryTest {
289289
def prepareRow(row: Row): Row = {
290290
Row.fromSeq(row.toSeq.map {
291291
case null => null
292-
case d: java.math.BigDecimal => BigDecimal(d)
292+
case bd: java.math.BigDecimal => BigDecimal(bd)
293293
// Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
294294
case seq: Seq[_] => seq.map {
295295
case b: java.lang.Byte => b.byteValue
@@ -303,6 +303,9 @@ object QueryTest {
303303
// Convert array to Seq for easy equality check.
304304
case b: Array[_] => b.toSeq
305305
case r: Row => prepareRow(r)
306+
// spark treats -0.0 as 0.0
307+
case d: Double if d == -0.0d => 0.0d
308+
case f: Float if f == -0.0f => 0.0f
306309
case o => o
307310
})
308311
}

0 commit comments

Comments
 (0)