Flink: Backport default values support in Parquet reader on Flink v1.18 and v1.19 (apache#12072)
jbonofre authored Jan 23, 2025
1 parent 908bdc3 commit 17bda20
Showing 4 changed files with 76 additions and 32 deletions.
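This backport teaches the Flink Parquet value readers to honor Iceberg column default values: when a projected field has no column in the data file but declares an initial default, the reader emits that default instead of null, and a required field with neither data nor a default now fails loudly. The diffs below change the struct reader builder (used by FlinkParquetReaders.buildReader) and its test, TestFlinkParquetReader. As a minimal, self-contained sketch of what the feature means on the read path (plain Java, not Iceberg's API; the column names and values are invented):

// Minimal illustration (plain Java, not Iceberg's API): a projection may contain a
// column that the data file was written without; if that column declares an initial
// default, every row read from the old file reports the default instead of null.
import java.util.List;
import java.util.Map;

public class DefaultValueReadSketch {
  // A projected field: name, initial default (null means none), and optionality.
  record Field(String name, Object initialDefault, boolean optional) {}

  public static void main(String[] args) {
    // Rows from an "old" data file that only contains the "id" column.
    List<Map<String, Integer>> oldFileRows = List.of(Map.of("id", 1), Map.of("id", 2));

    // Projection from the evolved table schema: "category" was added later with a default.
    List<Field> projection =
        List.of(new Field("id", null, false), new Field("category", "unknown", true));

    for (Map<String, Integer> fileRow : oldFileRows) {
      StringBuilder row = new StringBuilder();
      for (Field field : projection) {
        Object value =
            fileRow.containsKey(field.name())
                ? fileRow.get(field.name()) // column present in the file: use the stored value
                : field.initialDefault(); // column missing: fall back to the initial default
        row.append(field.name()).append('=').append(value).append(' ');
      }
      System.out.println(row.toString().trim());
    }
    // Prints: "id=1 category=unknown" and "id=2 category=unknown"
  }
}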
@@ -89,6 +89,7 @@ public ParquetValueReader<RowData> message(
   }

   @Override
+  @SuppressWarnings("checkstyle:CyclomaticComplexity")
   public ParquetValueReader<RowData> struct(
       Types.StructType expected, GroupType struct, List<ParquetValueReader<?>> fieldReaders) {
     // match the expected struct's order
@@ -120,6 +121,7 @@ public ParquetValueReader<RowData> struct(
     int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath());
     for (Types.NestedField field : expectedFields) {
       int id = field.fieldId();
+      ParquetValueReader<?> reader = readersById.get(id);
       if (idToConstant.containsKey(id)) {
         // containsKey is used because the constant may be null
         int fieldMaxDefinitionLevel =
@@ -133,15 +135,21 @@ public ParquetValueReader<RowData> struct(
       } else if (id == MetadataColumns.IS_DELETED.fieldId()) {
         reorderedFields.add(ParquetValueReaders.constant(false));
         types.add(null);
+      } else if (reader != null) {
+        reorderedFields.add(reader);
+        types.add(typesById.get(id));
+      } else if (field.initialDefault() != null) {
+        reorderedFields.add(
+            ParquetValueReaders.constant(
+                RowDataUtil.convertConstant(field.type(), field.initialDefault()),
+                maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel)));
+        types.add(typesById.get(id));
+      } else if (field.isOptional()) {
+        reorderedFields.add(ParquetValueReaders.nulls());
+        types.add(null);
       } else {
-        ParquetValueReader<?> reader = readersById.get(id);
-        if (reader != null) {
-          reorderedFields.add(reader);
-          types.add(typesById.get(id));
-        } else {
-          reorderedFields.add(ParquetValueReaders.nulls());
-          types.add(null);
-        }
+        throw new IllegalArgumentException(
+            String.format("Missing required field: %s", field.name()));
       }
     }

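After this hunk, each projected field is resolved in a fixed order: an id-to-constant binding first, then the metadata columns (row position, is-deleted), then a reader for a column that actually exists in the file, then the field's initial default, then nulls for an optional field; a required field that matches none of these now throws IllegalArgumentException instead of silently reading as null. The earlier hunk hoists readersById.get(id) to the top of the loop so the file-column check can sit ahead of the default check. A compact sketch of that precedence, using stand-in types rather than the real Iceberg reader classes:

// Hedged sketch of the resolution order introduced above. Source and the boolean
// inputs are stand-ins; the real method builds ParquetValueReader instances rather
// than returning an enum.
enum Source { CONSTANT, METADATA, FILE_COLUMN, INITIAL_DEFAULT, NULLS }

final class FieldResolution {
  private FieldResolution() {}

  static Source resolve(
      boolean hasIdToConstant, // partition/constant value bound to this field id
      boolean isMetadataColumn, // ROW_POSITION or IS_DELETED
      boolean hasFileReader, // the column exists in the Parquet file
      boolean hasInitialDefault, // the schema declares field.initialDefault()
      boolean isOptional) {
    if (hasIdToConstant) {
      return Source.CONSTANT;
    } else if (isMetadataColumn) {
      return Source.METADATA;
    } else if (hasFileReader) {
      return Source.FILE_COLUMN;
    } else if (hasInitialDefault) {
      return Source.INITIAL_DEFAULT; // new: serve the column's initial default as a constant
    } else if (isOptional) {
      return Source.NULLS; // optional and absent: read as null
    } else {
      // new: required, absent, and without a default is now an error
      throw new IllegalArgumentException("Missing required field");
    }
  }
}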
@@ -59,6 +59,11 @@
 public class TestFlinkParquetReader extends DataTest {
   private static final int NUM_RECORDS = 100;

+  @Override
+  protected boolean supportsDefaultValues() {
+    return true;
+  }
+
   @Test
   public void testBuildReader() {
     MessageType fileSchema =
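Overriding supportsDefaultValues() to return true opts this suite into the default-value cases defined by the shared DataTest base class; those cases are not part of this diff. A hypothetical sketch of how such a gate is typically wired (the class and method names below are assumptions, not the actual DataTest code):

// Hypothetical gate, assuming an AssertJ-style assumption; the real default-value
// tests live in DataTest and are not shown in this commit.
import static org.assertj.core.api.Assumptions.assumeThat;

abstract class DefaultValueGateSketch {
  protected boolean supportsDefaultValues() {
    return false; // subclasses whose readers handle defaults override this to true
  }

  void defaultValueCase() {
    assumeThat(supportsDefaultValues()).isTrue(); // skipped unless the subclass opts in
    // ... build a write schema without the new column, an expected schema that adds it
    // with an initial default, then round-trip with writeAndValidate(writeSchema, expectedSchema)
  }
}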
@@ -199,41 +204,50 @@ public void testTwoLevelList() throws IOException {
     }
   }

-  private void writeAndValidate(Iterable<Record> iterable, Schema schema) throws IOException {
+  private void writeAndValidate(
+      Iterable<Record> iterable, Schema writeSchema, Schema expectedSchema) throws IOException {
     File testFile = File.createTempFile("junit", null, temp.toFile());
     assertThat(testFile.delete()).isTrue();

     try (FileAppender<Record> writer =
         Parquet.write(Files.localOutput(testFile))
-            .schema(schema)
+            .schema(writeSchema)
             .createWriterFunc(GenericParquetWriter::buildWriter)
             .build()) {
       writer.addAll(iterable);
     }

     try (CloseableIterable<RowData> reader =
         Parquet.read(Files.localInput(testFile))
-            .project(schema)
-            .createReaderFunc(type -> FlinkParquetReaders.buildReader(schema, type))
+            .project(expectedSchema)
+            .createReaderFunc(type -> FlinkParquetReaders.buildReader(expectedSchema, type))
             .build()) {
       Iterator<Record> expected = iterable.iterator();
       Iterator<RowData> rows = reader.iterator();
-      LogicalType rowType = FlinkSchemaUtil.convert(schema);
+      LogicalType rowType = FlinkSchemaUtil.convert(writeSchema);
       for (int i = 0; i < NUM_RECORDS; i += 1) {
         assertThat(rows).hasNext();
-        TestHelpers.assertRowData(schema.asStruct(), rowType, expected.next(), rows.next());
+        TestHelpers.assertRowData(writeSchema.asStruct(), rowType, expected.next(), rows.next());
       }
       assertThat(rows).isExhausted();
     }
   }

   @Override
   protected void writeAndValidate(Schema schema) throws IOException {
-    writeAndValidate(RandomGenericData.generate(schema, NUM_RECORDS, 19981), schema);
+    writeAndValidate(RandomGenericData.generate(schema, NUM_RECORDS, 19981), schema, schema);
     writeAndValidate(
-        RandomGenericData.generateDictionaryEncodableRecords(schema, NUM_RECORDS, 21124), schema);
+        RandomGenericData.generateDictionaryEncodableRecords(schema, NUM_RECORDS, 21124),
+        schema,
+        schema);
     writeAndValidate(
         RandomGenericData.generateFallbackRecords(schema, NUM_RECORDS, 21124, NUM_RECORDS / 20),
+        schema,
         schema);
   }
+
+  @Override
+  protected void writeAndValidate(Schema writeSchema, Schema expectedSchema) throws IOException {
+    writeAndValidate(RandomGenericData.generate(writeSchema, 100, 0L), writeSchema, expectedSchema);
+  }
 }
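Splitting the helper's single schema into writeSchema and expectedSchema is what makes default-value round-trips expressible: records are generated and written with the old schema, but projected and compared with a newer schema whose extra columns the reader must fill. The three pre-existing calls simply pass the same schema twice, while the new two-argument override is presumably the hook DataTest's default-value cases call. Purely for illustration, a schema pair such a case might supply (ids, names, and types are invented; how the initial default is attached to the new field is omitted here):

import org.apache.iceberg.Schema;
import org.apache.iceberg.types.Types;

// Illustrative only: the write-side schema matches what the data files contain,
// while the read-side schema adds a column the files never had. In a real
// default-values test that extra field would also declare an initial default.
class TwoSchemaSketch {
  static final Schema WRITE_SCHEMA =
      new Schema(Types.NestedField.required(1, "id", Types.LongType.get()));

  static final Schema EXPECTED_SCHEMA =
      new Schema(
          Types.NestedField.required(1, "id", Types.LongType.get()),
          Types.NestedField.optional(2, "category", Types.StringType.get()));

  private TwoSchemaSketch() {}
}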
The remaining two changed files are the same reader and test classes in the commit's other Flink version module; their diffs are identical to the two shown above, since the change is backported to both v1.18 and v1.19.
