Skip to content

Commit 7f94c4b

Browse files
committed
fix Float16 statistics handling for NaN and zero values
1 parent 6c7cefd commit 7f94c4b

File tree

5 files changed

+87
-28
lines changed

5 files changed

+87
-28
lines changed

parquet-column/src/main/java/org/apache/parquet/column/statistics/Statistics.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,6 @@ public Statistics<?> build() {
142142

143143
// Builder for FLOAT16 type to handle special cases of min/max values like NaN, -0.0, and 0.0
144144
private static class Float16Builder extends Builder {
145-
private static final Binary POSITIVE_ZERO_LITTLE_ENDIAN = Binary.fromConstantByteArray(new byte[] {0x00, 0x00});
146-
private static final Binary NEGATIVE_ZERO_LITTLE_ENDIAN =
147-
Binary.fromConstantByteArray(new byte[] {0x00, (byte) 0x80});
148-
149145
public Float16Builder(PrimitiveType type) {
150146
super(type);
151147
assert type.getPrimitiveTypeName() == PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;
@@ -162,15 +158,17 @@ public Statistics<?> build() {
162158
short max = bMax.get2BytesLittleEndian();
163159
// Drop min/max values in case of NaN as the sorting order of values is undefined for this case
164160
if (Float16.isNaN(min) || Float16.isNaN(max)) {
165-
stats.setMinMax(POSITIVE_ZERO_LITTLE_ENDIAN, NEGATIVE_ZERO_LITTLE_ENDIAN);
161+
stats.setMinMax(Float16.POSITIVE_ZERO_LITTLE_ENDIAN, Float16.POSITIVE_ZERO_LITTLE_ENDIAN);
166162
((Statistics<?>) stats).hasNonNullValue = false;
167163
} else {
168164
// Updating min to -0.0 and max to +0.0 to ensure that no 0.0 values would be skipped
169165
if (min == (short) 0x0000) {
170-
stats.setMinMax(NEGATIVE_ZERO_LITTLE_ENDIAN, bMax);
166+
bMin = Float16.NEGATIVE_ZERO_LITTLE_ENDIAN;
167+
stats.setMinMax(bMin, bMax);
171168
}
172169
if (max == (short) 0x8000) {
173-
stats.setMinMax(bMin, POSITIVE_ZERO_LITTLE_ENDIAN);
170+
bMax = Float16.POSITIVE_ZERO_LITTLE_ENDIAN;
171+
stats.setMinMax(bMin, bMax);
174172
}
175173
}
176174
}

parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/BinaryColumnIndexBuilder.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.List;
2424
import org.apache.parquet.filter2.predicate.Statistics;
2525
import org.apache.parquet.io.api.Binary;
26+
import org.apache.parquet.schema.Float16;
27+
import org.apache.parquet.schema.LogicalTypeAnnotation;
2628
import org.apache.parquet.schema.PrimitiveComparator;
2729
import org.apache.parquet.schema.PrimitiveType;
2830

@@ -82,6 +84,8 @@ int compareValueToMax(int arrayIndex) {
8284
private final List<Binary> maxValues = new ArrayList<>();
8385
private final BinaryTruncator truncator;
8486
private final int truncateLength;
87+
private final boolean isFloat16;
88+
private boolean invalid;
8589

8690
private static Binary convert(ByteBuffer buffer) {
8791
return Binary.fromReusedByteBuffer(buffer);
@@ -94,6 +98,7 @@ private static ByteBuffer convert(Binary value) {
9498
BinaryColumnIndexBuilder(PrimitiveType type, int truncateLength) {
9599
truncator = BinaryTruncator.getTruncator(type);
96100
this.truncateLength = truncateLength;
101+
this.isFloat16 = type.getLogicalTypeAnnotation() instanceof LogicalTypeAnnotation.Float16LogicalTypeAnnotation;
97102
}
98103

99104
@Override
@@ -104,12 +109,43 @@ void addMinMaxFromBytes(ByteBuffer min, ByteBuffer max) {
104109

105110
@Override
106111
void addMinMax(Object min, Object max) {
107-
minValues.add(min == null ? null : truncator.truncateMin((Binary) min, truncateLength));
108-
maxValues.add(max == null ? null : truncator.truncateMax((Binary) max, truncateLength));
112+
Binary bMin = (Binary) min;
113+
Binary bMax = (Binary) max;
114+
115+
if (isFloat16 && bMin != null && bMax != null) {
116+
if (bMin.length() != LogicalTypeAnnotation.Float16LogicalTypeAnnotation.BYTES
117+
|| bMax.length() != LogicalTypeAnnotation.Float16LogicalTypeAnnotation.BYTES) {
118+
// Should not happen for Float16
119+
invalid = true;
120+
} else {
121+
short sMin = bMin.get2BytesLittleEndian();
122+
short sMax = bMax.get2BytesLittleEndian();
123+
124+
if (Float16.isNaN(sMin) || Float16.isNaN(sMax)) {
125+
invalid = true;
126+
}
127+
128+
// Sorting order is undefined for -0.0 so let min = -0.0 and max = +0.0 to
129+
// ensure that no 0.0 values are skipped
130+
// +0.0 is 0x0000, -0.0 is 0x8000 (little endian: 00 00, 00 80)
131+
if (sMin == (short) 0x0000) {
132+
bMin = Float16.NEGATIVE_ZERO_LITTLE_ENDIAN;
133+
}
134+
if (sMax == (short) 0x8000) {
135+
bMax = Float16.POSITIVE_ZERO_LITTLE_ENDIAN;
136+
}
137+
}
138+
}
139+
140+
minValues.add(bMin == null ? null : truncator.truncateMin(bMin, truncateLength));
141+
maxValues.add(bMax == null ? null : truncator.truncateMax(bMax, truncateLength));
109142
}
110143

111144
@Override
112145
ColumnIndexBase<Binary> createColumnIndex(PrimitiveType type) {
146+
if (invalid) {
147+
return null;
148+
}
113149
BinaryColumnIndex columnIndex = new BinaryColumnIndex(type);
114150
columnIndex.minValues = minValues.toArray(new Binary[0]);
115151
columnIndex.maxValues = maxValues.toArray(new Binary[0]);

parquet-column/src/main/java/org/apache/parquet/schema/Float16.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
* Ref: https://android.googlesource.com/platform/libcore/+/master/luni/src/main/java/libcore/util/FP16.java
4747
*/
4848
public class Float16 {
49+
// Positive zero of type half-precision float.
50+
public static final Binary POSITIVE_ZERO_LITTLE_ENDIAN = Binary.fromConstantByteArray(new byte[] {0x00, 0x00});
51+
// Negative zero of type half-precision float.
52+
public static final Binary NEGATIVE_ZERO_LITTLE_ENDIAN =
53+
Binary.fromConstantByteArray(new byte[] {0x00, (byte) 0x80});
54+
4955
// Positive infinity of type half-precision float.
5056
private static final short POSITIVE_INFINITY = (short) 0x7c00;
5157
// A Not-a-Number representation of a half-precision float.

parquet-hadoop/src/test/java/org/apache/parquet/statistics/TestFloat16ReadWriteRoundTrip.java

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -122,37 +122,25 @@ public class TestFloat16ReadWriteRoundTrip {
122122
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x7c})
123123
}; // Infinity
124124

125-
private Binary[] valuesAllPositiveZeroMinMax = {
126-
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x00}), // +0
125+
private Binary[] valuesAllZeroMinMax = {
126+
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x80}), // -0
127127
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x00})
128128
}; // +0
129129

130-
private Binary[] valuesAllNegativeZeroMinMax = {
131-
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x80}), // -0
132-
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x80})
133-
}; // -0
134-
135-
private Binary[] valuesWithNaNMinMax = {
136-
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0xc0}), // -2.0
137-
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x7e})
138-
}; // NaN
139-
140130
@Test
141131
public void testFloat16ColumnIndex() throws IOException {
142132
List<Binary[]> testValues = List.of(
143133
valuesInAscendingOrder,
144134
valuesInDescendingOrder,
145135
valuesUndefinedOrder,
146136
valuesAllPositiveZero,
147-
valuesAllNegativeZero,
148-
valuesWithNaN);
137+
valuesAllNegativeZero);
149138
List<Binary[]> expectedValues = List.of(
150139
valuesInAscendingOrderMinMax,
151140
valuesInDescendingOrderMinMax,
152141
valuesUndefinedOrderMinMax,
153-
valuesAllPositiveZeroMinMax,
154-
valuesAllNegativeZeroMinMax,
155-
valuesWithNaNMinMax);
142+
valuesAllZeroMinMax,
143+
valuesAllZeroMinMax);
156144

157145
for (int i = 0; i < testValues.size(); i++) {
158146
MessageType schema = Types.buildMessage()
@@ -187,6 +175,37 @@ public void testFloat16ColumnIndex() throws IOException {
187175
}
188176
}
189177

178+
@Test
179+
public void testFloat16NanColumnIndex() throws IOException {
180+
MessageType schema = Types.buildMessage()
181+
.required(FIXED_LEN_BYTE_ARRAY)
182+
.as(float16Type())
183+
.length(2)
184+
.named("col_float16")
185+
.named("msg");
186+
187+
Configuration conf = new Configuration();
188+
GroupWriteSupport.setSchema(schema, conf);
189+
GroupFactory factory = new SimpleGroupFactory(schema);
190+
Path path = newTempPath();
191+
try (ParquetWriter<Group> writer = ExampleParquetWriter.builder(path)
192+
.withConf(conf)
193+
.withDictionaryEncoding(false)
194+
.build()) {
195+
196+
for (Binary value : valuesWithNaN) {
197+
writer.write(factory.newGroup().append("col_float16", value));
198+
}
199+
}
200+
201+
try (ParquetFileReader reader = ParquetFileReader.open(HadoopInputFile.fromPath(path, new Configuration()))) {
202+
ColumnChunkMetaData column =
203+
reader.getFooter().getBlocks().get(0).getColumns().get(0);
204+
ColumnIndex index = reader.readColumnIndex(column);
205+
assertEquals(index, null);
206+
}
207+
}
208+
190209
private Path newTempPath() throws IOException {
191210
File file = temp.newFile();
192211
Preconditions.checkArgument(file.delete(), "Could not remove temp file");

parquet-hadoop/src/test/java/org/apache/parquet/statistics/TestFloat16Statistics.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ public class TestFloat16Statistics {
135135
// Float16Builder: Drop min/max values in case of NaN as the sorting order of values is undefined
136136
private Binary[] valuesWithNaNStatsMinMax = {
137137
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x00}), // +0
138-
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x80})
139-
}; // -0
138+
Binary.fromConstantByteArray(new byte[] {(byte) 0x00, (byte) 0x00})
139+
}; // +0
140140

141141
@Test
142142
public void testFloat16StatisticsMultipleCases() throws IOException {

0 commit comments

Comments
 (0)