Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flink: Fix range distribution npe when value is null #11662

Merged
merged 9 commits into from
Dec 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,7 @@ public void processElement(StreamRecord<RowData> streamRecord) {
RowData record = streamRecord.getValue();
StructLike struct = rowDataWrapper.wrap(record);
sortKey.wrap(struct);
boolean containNull = false;
for (int i = 0; i < sortKey.size(); ++i) {
if (null == sortKey.get(i, Object.class)) {
containNull = true;
break;
}
}
if (!containNull) {
localStatistics.add(sortKey);
}
localStatistics.add(sortKey);

checkStatisticsTypeMigration();
output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,7 @@ public void processElement(StreamRecord<RowData> streamRecord) {
RowData record = streamRecord.getValue();
StructLike struct = rowDataWrapper.wrap(record);
sortKey.wrap(struct);
boolean containNull = false;
for (int i = 0; i < sortKey.size(); ++i) {
if (null == sortKey.get(i, Object.class)) {
containNull = true;
break;
}
}
if (!containNull) {
localStatistics.add(sortKey);
}
localStatistics.add(sortKey);

checkStatisticsTypeMigration();
output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,35 @@ class CompletedStatisticsSerializer extends TypeSerializer<CompletedStatistics>
private final MapSerializer<SortKey, Long> keyFrequencySerializer;
private final ListSerializer<SortKey> keySamplesSerializer;

private int sortKeySerializerVersion = -1;

CompletedStatisticsSerializer(TypeSerializer<SortKey> sortKeySerializer) {
this.sortKeySerializer = sortKeySerializer;
this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class);
this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE);
this.keySamplesSerializer = new ListSerializer<>(sortKeySerializer);
}

/**
 * Sets the version used by the wrapped sort key serializer and remembers it in this instance.
 *
 * <p>Does nothing when the wrapped serializer is not a {@link SortKeySerializer}.
 *
 * @param version version to hand to the wrapped {@link SortKeySerializer}
 */
public void changeSortKeySerializerVersion(int version) {
  if (!(sortKeySerializer instanceof SortKeySerializer)) {
    return;
  }

  SortKeySerializer delegate = (SortKeySerializer) sortKeySerializer;
  delegate.setVersion(version);
  this.sortKeySerializerVersion = version;
}

/**
 * Asks the wrapped sort key serializer to go back to its latest version.
 *
 * <p>Does nothing when the wrapped serializer is not a {@link SortKeySerializer}.
 */
public void changeSortKeySerializerVersionLatest() {
  if (!(sortKeySerializer instanceof SortKeySerializer)) {
    return;
  }

  SortKeySerializer delegate = (SortKeySerializer) sortKeySerializer;
  delegate.restoreToLatestVersion();
}

/**
 * Returns the latest version reported by the wrapped {@link SortKeySerializer}.
 *
 * @return the wrapped serializer's latest version, or the locally stored
 *     {@code sortKeySerializerVersion} when the wrapped serializer is of another type
 */
public int getSortKeySerializerVersionLatest() {
  return sortKeySerializer instanceof SortKeySerializer
      ? ((SortKeySerializer) sortKeySerializer).getLatestVersion()
      : sortKeySerializerVersion;
}

@Override
public boolean isImmutableType() {
return false;
Expand Down Expand Up @@ -82,6 +104,17 @@ public int getLength() {

/**
 * Writes the statistics in this order: checkpoint id, sort key serializer version, statistics
 * type, then either the key-frequency map or the list of key samples.
 *
 * @param record statistics to write
 * @param target output view receiving the bytes
 * @throws IOException if writing to {@code target} fails
 */
@Override
public void serialize(CompletedStatistics record, DataOutputView target) throws IOException {
  target.writeLong(record.checkpointId());
  // NOTE(review): the header always records the latest version; presumably the nested sort key
  // serializer is also at its latest version whenever this runs -- confirm.
  target.writeInt(getSortKeySerializerVersionLatest());

  StatisticsType type = record.type();
  statisticsTypeSerializer.serialize(type, target);
  if (type != StatisticsType.Map) {
    keySamplesSerializer.serialize(Arrays.asList(record.keySamples()), target);
    return;
  }

  keyFrequencySerializer.serialize(record.keyFrequency(), target);
}

public void serializeV1(CompletedStatistics record, DataOutputView target) throws IOException {
target.writeLong(record.checkpointId());
statisticsTypeSerializer.serialize(record.type(), target);
if (record.type() == StatisticsType.Map) {
Expand All @@ -93,15 +126,35 @@ public void serialize(CompletedStatistics record, DataOutputView target) throws

@Override
public CompletedStatistics deserialize(DataInputView source) throws IOException {
long checkpointId = source.readLong();
changeSortKeySerializerVersion(source.readInt());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, we were not writing the serializer version before this change, so I'm not sure this works when restoring the serializer to read old data.

StatisticsType type = statisticsTypeSerializer.deserialize(source);
if (type == StatisticsType.Map) {
Map<SortKey, Long> keyFrequency = keyFrequencySerializer.deserialize(source);
changeSortKeySerializerVersionLatest();
return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency);
} else {
List<SortKey> sortKeys = keySamplesSerializer.deserialize(source);
SortKey[] keySamples = new SortKey[sortKeys.size()];
keySamples = sortKeys.toArray(keySamples);
changeSortKeySerializerVersionLatest();
return CompletedStatistics.fromKeySamples(checkpointId, keySamples);
}
}

/**
 * Reads statistics written in the layout that has no sort key serializer version after the
 * checkpoint id: checkpoint id, statistics type, then the key-frequency map or the key samples.
 *
 * <p>The sort key serializer is switched to version 1 while the keys are read. It is restored to
 * its latest version in a {@code finally} block, so a failed read cannot leave the shared
 * serializer pinned to version 1.
 *
 * @param source input view positioned at the start of the serialized statistics
 * @return the deserialized statistics
 * @throws IOException if reading from {@code source} fails
 */
public CompletedStatistics deserializeV1(DataInputView source) throws IOException {
  long checkpointId = source.readLong();
  StatisticsType type = statisticsTypeSerializer.deserialize(source);
  changeSortKeySerializerVersion(1);
  try {
    if (type == StatisticsType.Map) {
      Map<SortKey, Long> keyFrequency = keyFrequencySerializer.deserialize(source);
      return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency);
    } else {
      List<SortKey> sortKeys = keySamplesSerializer.deserialize(source);
      SortKey[] keySamples = sortKeys.toArray(new SortKey[0]);
      return CompletedStatistics.fromKeySamples(checkpointId, keySamples);
    }
  } finally {
    // Always undo the temporary downgrade, including when deserialization throws.
    changeSortKeySerializerVersionLatest();
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ public void resetToCheckpoint(long checkpointId, byte[] checkpointData) {
"Restoring data statistic coordinator {} from checkpoint {}", operatorName, checkpointId);
this.completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
checkpointData, completedStatisticsSerializer);
checkpointData, (CompletedStatisticsSerializer) completedStatisticsSerializer);

// recompute global statistics in case downstream parallelism changed
this.globalStatistics =
globalStatistics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,7 @@ public void processElement(StreamRecord<RowData> streamRecord) {
RowData record = streamRecord.getValue();
StructLike struct = rowDataWrapper.wrap(record);
sortKey.wrap(struct);
boolean containNull = false;
for (int i = 0; i < sortKey.size(); ++i) {
if (null == sortKey.get(i, Object.class)) {
containNull = true;
break;
}
}
if (!containNull) {
localStatistics.add(sortKey);
}
localStatistics.add(sortKey);

checkStatisticsTypeMigration();
output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,32 @@ class SortKeySerializer extends TypeSerializer<SortKey> {
private final int size;
private final Types.NestedField[] transformedFields;

private int version = SortKeySerializerSnapshot.CURRENT_VERSION;

private transient SortKey sortKey;

SortKeySerializer(Schema schema, SortOrder sortOrder, int version) {
this.version = version;
this.schema = schema;
this.sortOrder = sortOrder;
this.size = sortOrder.fields().size();

this.transformedFields = new Types.NestedField[size];
for (int i = 0; i < size; ++i) {
SortField sortField = sortOrder.fields().get(i);
Types.NestedField sourceField = schema.findField(sortField.sourceId());
Type resultType = sortField.transform().getResultType(sourceField.type());
Types.NestedField transformedField =
Types.NestedField.of(
sourceField.fieldId(),
sourceField.isOptional(),
sourceField.name(),
resultType,
sourceField.doc());
transformedFields[i] = transformedField;
}
Guosmilesmile marked this conversation as resolved.
Show resolved Hide resolved
}

SortKeySerializer(Schema schema, SortOrder sortOrder) {
this.schema = schema;
this.sortOrder = sortOrder;
Expand Down Expand Up @@ -83,6 +107,18 @@ private SortKey lazySortKey() {
return sortKey;
}

/**
 * Returns the current version reported by this serializer's snapshot configuration.
 *
 * @return the value of {@code snapshotConfiguration().getCurrentVersion()}
 */
public int getLatestVersion() {
  TypeSerializerSnapshot<SortKey> snapshot = snapshotConfiguration();
  return snapshot.getCurrentVersion();
}

/**
 * Resets this serializer's version to the current version reported by its snapshot
 * configuration.
 */
public void restoreToLatestVersion() {
  TypeSerializerSnapshot<SortKey> snapshot = snapshotConfiguration();
  this.version = snapshot.getCurrentVersion();
}

/**
 * Overrides the version this serializer uses; per-field null markers are only written and read
 * when the version is greater than 1.
 *
 * @param version version to use from now on
 */
public void setVersion(int version) {
this.version = version;
}

@Override
public boolean isImmutableType() {
return false;
Expand Down Expand Up @@ -124,6 +160,16 @@ public void serialize(SortKey record, DataOutputView target) throws IOException
for (int i = 0; i < size; ++i) {
int fieldId = transformedFields[i].fieldId();
Type.TypeID typeId = transformedFields[i].type().typeId();
if (version > 1) {
Object value = record.get(i, Object.class);
if (value == null) {
target.writeBoolean(true);
continue;
} else {
target.writeBoolean(false);
}
}
pvary marked this conversation as resolved.
Show resolved Hide resolved

switch (typeId) {
case BOOLEAN:
target.writeBoolean(record.get(i, Boolean.class));
Expand Down Expand Up @@ -192,6 +238,14 @@ public SortKey deserialize(SortKey reuse, DataInputView source) throws IOExcepti
reuse.size(),
size);
for (int i = 0; i < size; ++i) {
if (version > 1) {
boolean isNull = source.readBoolean();
if (isNull) {
reuse.set(i, null);
continue;
}
}
pvary marked this conversation as resolved.
Show resolved Hide resolved

int fieldId = transformedFields[i].fieldId();
Type.TypeID typeId = transformedFields[i].type().typeId();
switch (typeId) {
Expand Down Expand Up @@ -276,11 +330,13 @@ public TypeSerializerSnapshot<SortKey> snapshotConfiguration() {
}

public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot<SortKey> {
private static final int CURRENT_VERSION = 1;
private static final int CURRENT_VERSION = 2;

private Schema schema;
private SortOrder sortOrder;

private int version = CURRENT_VERSION;

/** Constructor for read instantiation. */
@SuppressWarnings({"unused", "checkstyle:RedundantModifier"})
public SortKeySerializerSnapshot() {
Expand Down Expand Up @@ -310,10 +366,16 @@ public void writeSnapshot(DataOutputView out) throws IOException {
@Override
public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader)
throws IOException {
if (readVersion == 1) {
readV1(in);
} else {
throw new IllegalArgumentException("Unknown read version: " + readVersion);
switch (readVersion) {
case 1:
readV1(in);
this.version = 1;
break;
case 2:
readV1(in);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove the version suffix in the method name? It handles both versions.

break;
default:
throw new IllegalArgumentException("Unknown read version: " + readVersion);
}
}

Expand All @@ -324,6 +386,10 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
return TypeSerializerSchemaCompatibility.incompatible();
}

if (oldSerializerSnapshot.getCurrentVersion() == 1 && this.getCurrentVersion() == 2) {
return TypeSerializerSchemaCompatibility.compatibleAfterMigration();
}

// Sort order should be identical
SortKeySerializerSnapshot oldSnapshot = (SortKeySerializerSnapshot) oldSerializerSnapshot;
if (!sortOrder.sameOrder(oldSnapshot.sortOrder)) {
Expand All @@ -349,7 +415,7 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
public TypeSerializer<SortKey> restoreSerializer() {
Preconditions.checkState(schema != null, "Invalid schema: null");
Preconditions.checkState(sortOrder != null, "Invalid sort order: null");
return new SortKeySerializer(schema, sortOrder);
return new SortKeySerializer(schema, sortOrder, version);
}

private void readV1(DataInputView in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,17 @@ static byte[] serializeCompletedStatistics(
}

static CompletedStatistics deserializeCompletedStatistics(
byte[] bytes, TypeSerializer<CompletedStatistics> statisticsSerializer) {
byte[] bytes, CompletedStatisticsSerializer statisticsSerializer) {
try {
DataInputDeserializer input = new DataInputDeserializer(bytes);
return statisticsSerializer.deserialize(input);
} catch (IOException e) {
throw new UncheckedIOException("Fail to deserialize aggregated statistics", e);
} catch (Exception e) {
try {
DataInputDeserializer input = new DataInputDeserializer(bytes);
return statisticsSerializer.deserializeV1(input);
} catch (IOException ioException) {
throw new UncheckedIOException("Fail to deserialize aggregated statistics", ioException);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, we're retrying here in case the restore fails.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.iceberg.flink.TestFixtures;
import org.apache.iceberg.flink.sink.shuffle.StatisticsType;
import org.apache.iceberg.flink.source.BoundedTestSource;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
Expand Down Expand Up @@ -252,6 +253,44 @@ public void testRangeDistributionWithoutSortOrderPartitioned() throws Exception
assertThat(snapshots).hasSizeGreaterThanOrEqualTo(numOfCheckpoints);
}

/**
 * Runs a range-distributed write where one input row has a null second field and checks that at
 * least {@code numOfCheckpoints} snapshots with added data files were produced. Only runs for the
 * partitioned test variant.
 */
@TestTemplate
public void testRangeDistributionWithNullValue() throws Exception {
  assumeThat(partitioned).isTrue();

  table
      .updateProperties()
      .set(TableProperties.WRITE_DISTRIBUTION_MODE, DistributionMode.RANGE.modeName())
      .commit();

  int numOfCheckpoints = 6;
  List<List<Row>> rowsPerCheckpoint = createCharRows(numOfCheckpoints, 10);
  // one extra batch holding a single row whose second field is null
  rowsPerCheckpoint.add(ImmutableList.of(Row.of(1, null)));

  DataStream<Row> input =
      env.addSource(createRangeDistributionBoundedSource(rowsPerCheckpoint), ROW_TYPE_INFO);
  FlinkSink.forRow(input, SimpleDataUtil.FLINK_SCHEMA)
      .table(table)
      .tableLoader(tableLoader)
      .writeParallelism(parallelism)
      .append();
  env.execute(getClass().getSimpleName());

  table.refresh();
  // keep only the snapshots that added at least one data file
  List<Snapshot> snapshotsWithData = Lists.newArrayList();
  for (Snapshot snapshot : table.snapshots()) {
    if (snapshot.addedDataFiles(table.io()).iterator().hasNext()) {
      snapshotsWithData.add(snapshot);
    }
  }

  // Sometimes we will have more checkpoints than the bounded source if we pass the
  // auto checkpoint interval. Thus producing multiple snapshots.
  assertThat(snapshotsWithData).hasSizeGreaterThanOrEqualTo(numOfCheckpoints);
}

@TestTemplate
public void testRangeDistributionWithSortOrder() throws Exception {
table
Expand Down
Loading