Flink: SortKeySerializer and CompletedStatisticsSerializer support version and null sort key, and add UT for the change
Guosmilesmile committed Dec 1, 2024
1 parent cbcf744 commit a49b87b
Showing 13 changed files with 354 additions and 42 deletions.
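In short, as reconstructed from the diff below (not an official format spec): SortKeySerializer now carries an explicit version, and version 2 prefixes every sort-key field with a null marker so that null sort values survive serialization. CompletedStatisticsSerializer stamps each checkpoint with the sort-key serializer version that wrote it, and StatisticsUtil gains a fallback path that still restores version-1 checkpoint bytes.

    // Assumed v2 wire layout, inferred from the diff (illustrative only):
    //
    // CompletedStatistics record:
    //   long    checkpointId
    //   int     sortKeySerializerVersion     // new in this commit
    //   enum    statisticsType
    //   <key-frequency map or key-samples list>
    //
    // Each SortKey field, when the sort-key serializer version is > 1:
    //   boolean isNull                       // new in this commit
    //   <typed field value>                  // written only when isNull == false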
---- changed file ----

@@ -181,16 +181,7 @@ public void processElement(StreamRecord<RowData> streamRecord) {
     RowData record = streamRecord.getValue();
     StructLike struct = rowDataWrapper.wrap(record);
     sortKey.wrap(struct);
-    boolean containNull = false;
-    for (int i = 0; i < sortKey.size(); ++i) {
-      if (null == sortKey.get(i, Object.class)) {
-        containNull = true;
-        break;
-      }
-    }
-    if (!containNull) {
-      localStatistics.add(sortKey);
-    }
+    localStatistics.add(sortKey);
 
     checkStatisticsTypeMigration();
     output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record)));
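The deleted block was a workaround: version 1 of SortKeySerializer could not encode null fields, so rows with a null sort value had to be kept out of the local statistics. Version 2 round-trips such keys, which is what makes the one-line replacement safe. (The identical hunk repeats below for what are presumably the parallel per-Flink-version copies of this operator.) A minimal round-trip sketch, assuming a harness where schema, sortOrder, and keyWithNull (a SortKey holding a null field) already exist; DataOutputSerializer and DataInputDeserializer are Flink's in-memory data views:

    SortKeySerializer serializer = new SortKeySerializer(schema, sortOrder); // defaults to version 2
    DataOutputSerializer out = new DataOutputSerializer(128);
    serializer.serialize(keyWithNull, out); // writes a null marker ahead of each field
    SortKey restored = serializer.deserialize(new DataInputDeserializer(out.getCopyOfBuffer()));
    // restored carries the null field instead of the serializer throwing on it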
---- changed file ----

@@ -181,16 +181,7 @@ public void processElement(StreamRecord<RowData> streamRecord) {
     RowData record = streamRecord.getValue();
     StructLike struct = rowDataWrapper.wrap(record);
     sortKey.wrap(struct);
-    boolean containNull = false;
-    for (int i = 0; i < sortKey.size(); ++i) {
-      if (null == sortKey.get(i, Object.class)) {
-        containNull = true;
-        break;
-      }
-    }
-    if (!containNull) {
-      localStatistics.add(sortKey);
-    }
+    localStatistics.add(sortKey);
 
     checkStatisticsTypeMigration();
     output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record)));
---- changed file ----

@@ -41,13 +41,35 @@ class CompletedStatisticsSerializer extends TypeSerializer<CompletedStatistics>
   private final MapSerializer<SortKey, Long> keyFrequencySerializer;
   private final ListSerializer<SortKey> keySamplesSerializer;
 
+  private int sortKeySerializerVersion = -1;
+
   CompletedStatisticsSerializer(TypeSerializer<SortKey> sortKeySerializer) {
     this.sortKeySerializer = sortKeySerializer;
     this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class);
     this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE);
     this.keySamplesSerializer = new ListSerializer<>(sortKeySerializer);
   }
 
+  public void changeSortKeySerializerVersion(int version) {
+    if (sortKeySerializer instanceof SortKeySerializer) {
+      ((SortKeySerializer) sortKeySerializer).setVersion(version);
+      this.sortKeySerializerVersion = version;
+    }
+  }
+
+  public void changeSortKeySerializerVersionLatest() {
+    if (sortKeySerializer instanceof SortKeySerializer) {
+      ((SortKeySerializer) sortKeySerializer).restoreToLatestVersion();
+    }
+  }
+
+  public int getSortKeySerializerVersionLatest() {
+    if (sortKeySerializer instanceof SortKeySerializer) {
+      return ((SortKeySerializer) sortKeySerializer).getLatestVersion();
+    }
+    return sortKeySerializerVersion;
+  }
+
   @Override
   public boolean isImmutableType() {
     return false;
@@ -82,6 +104,17 @@ public int getLength() {
 
   @Override
   public void serialize(CompletedStatistics record, DataOutputView target) throws IOException {
     target.writeLong(record.checkpointId());
+    target.writeInt(getSortKeySerializerVersionLatest());
+    statisticsTypeSerializer.serialize(record.type(), target);
+    if (record.type() == StatisticsType.Map) {
+      keyFrequencySerializer.serialize(record.keyFrequency(), target);
+    } else {
+      keySamplesSerializer.serialize(Arrays.asList(record.keySamples()), target);
+    }
+  }
+
+  public void serializeV1(CompletedStatistics record, DataOutputView target) throws IOException {
+    target.writeLong(record.checkpointId());
     statisticsTypeSerializer.serialize(record.type(), target);
     if (record.type() == StatisticsType.Map) {
@@ -93,15 +126,35 @@ public void serialize(CompletedStatistics record, DataOutputView target) throws IOException {
 
   @Override
   public CompletedStatistics deserialize(DataInputView source) throws IOException {
     long checkpointId = source.readLong();
+    changeSortKeySerializerVersion(source.readInt());
     StatisticsType type = statisticsTypeSerializer.deserialize(source);
     if (type == StatisticsType.Map) {
       Map<SortKey, Long> keyFrequency = keyFrequencySerializer.deserialize(source);
+      changeSortKeySerializerVersionLatest();
       return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency);
     } else {
       List<SortKey> sortKeys = keySamplesSerializer.deserialize(source);
       SortKey[] keySamples = new SortKey[sortKeys.size()];
       keySamples = sortKeys.toArray(keySamples);
+      changeSortKeySerializerVersionLatest();
       return CompletedStatistics.fromKeySamples(checkpointId, keySamples);
     }
   }
+
+  public CompletedStatistics deserializeV1(DataInputView source) throws IOException {
+    long checkpointId = source.readLong();
+    StatisticsType type = statisticsTypeSerializer.deserialize(source);
+    changeSortKeySerializerVersion(1);
+    if (type == StatisticsType.Map) {
+      Map<SortKey, Long> keyFrequency = keyFrequencySerializer.deserialize(source);
+      changeSortKeySerializerVersionLatest();
+      return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency);
+    } else {
+      List<SortKey> sortKeys = keySamplesSerializer.deserialize(source);
+      SortKey[] keySamples = new SortKey[sortKeys.size()];
+      keySamples = sortKeys.toArray(keySamples);
+      changeSortKeySerializerVersionLatest();
+      return CompletedStatistics.fromKeySamples(checkpointId, keySamples);
+    }
+  }
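A hedged round-trip sketch of the version plumbing above (names are from this diff; stats and completedStatisticsSerializer are assumed from a harness): serialize() always stamps the latest sort-key version, while deserialize() pins the sort-key serializer to the version read from the bytes and resets it to the latest before returning.

    DataOutputSerializer out = new DataOutputSerializer(256);
    completedStatisticsSerializer.serialize(stats, out); // writes checkpointId, then the version int
    CompletedStatistics restored =
        completedStatisticsSerializer.deserialize(new DataInputDeserializer(out.getCopyOfBuffer()));
    // deserialize() calls changeSortKeySerializerVersion(source.readInt()) before decoding keys
    // and changeSortKeySerializerVersionLatest() before returning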
---- changed file ----

@@ -370,7 +370,8 @@ public void resetToCheckpoint(long checkpointId, byte[] checkpointData) {
         "Restoring data statistic coordinator {} from checkpoint {}", operatorName, checkpointId);
     this.completedStatistics =
         StatisticsUtil.deserializeCompletedStatistics(
-            checkpointData, completedStatisticsSerializer);
+            checkpointData, (CompletedStatisticsSerializer) completedStatisticsSerializer);
+
     // recompute global statistics in case downstream parallelism changed
     this.globalStatistics =
         globalStatistics(
---- changed file ----

@@ -181,16 +181,7 @@ public void processElement(StreamRecord<RowData> streamRecord) {
     RowData record = streamRecord.getValue();
     StructLike struct = rowDataWrapper.wrap(record);
    sortKey.wrap(struct);
-    boolean containNull = false;
-    for (int i = 0; i < sortKey.size(); ++i) {
-      if (null == sortKey.get(i, Object.class)) {
-        containNull = true;
-        break;
-      }
-    }
-    if (!containNull) {
-      localStatistics.add(sortKey);
-    }
+    localStatistics.add(sortKey);
 
     checkStatisticsTypeMigration();
     output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record)));
---- changed file ----

@@ -52,8 +52,32 @@ class SortKeySerializer extends TypeSerializer<SortKey> {
   private final int size;
   private final Types.NestedField[] transformedFields;
 
+  private int version = SortKeySerializerSnapshot.CURRENT_VERSION;
+
   private transient SortKey sortKey;
 
+  SortKeySerializer(Schema schema, SortOrder sortOrder, int version) {
+    this.version = version;
+    this.schema = schema;
+    this.sortOrder = sortOrder;
+    this.size = sortOrder.fields().size();
+
+    this.transformedFields = new Types.NestedField[size];
+    for (int i = 0; i < size; ++i) {
+      SortField sortField = sortOrder.fields().get(i);
+      Types.NestedField sourceField = schema.findField(sortField.sourceId());
+      Type resultType = sortField.transform().getResultType(sourceField.type());
+      Types.NestedField transformedField =
+          Types.NestedField.of(
+              sourceField.fieldId(),
+              sourceField.isOptional(),
+              sourceField.name(),
+              resultType,
+              sourceField.doc());
+      transformedFields[i] = transformedField;
+    }
+  }
+
   SortKeySerializer(Schema schema, SortOrder sortOrder) {
     this.schema = schema;
     this.sortOrder = sortOrder;
@@ -83,6 +107,18 @@ private SortKey lazySortKey() {
     return sortKey;
   }
 
+  public int getLatestVersion() {
+    return snapshotConfiguration().getCurrentVersion();
+  }
+
+  public void restoreToLatestVersion() {
+    this.version = snapshotConfiguration().getCurrentVersion();
+  }
+
+  public void setVersion(int version) {
+    this.version = version;
+  }
+
   @Override
   public boolean isImmutableType() {
     return false;
@@ -124,6 +160,16 @@ public void serialize(SortKey record, DataOutputView target) throws IOException
     for (int i = 0; i < size; ++i) {
       int fieldId = transformedFields[i].fieldId();
       Type.TypeID typeId = transformedFields[i].type().typeId();
+      if (version > 1) {
+        Object value = record.get(i, Object.class);
+        if (value == null) {
+          target.writeBoolean(true);
+          continue;
+        } else {
+          target.writeBoolean(false);
+        }
+      }
+
       switch (typeId) {
         case BOOLEAN:
           target.writeBoolean(record.get(i, Boolean.class));
@@ -192,6 +238,14 @@ public SortKey deserialize(SortKey reuse, DataInputView source) throws IOException {
         reuse.size(),
         size);
     for (int i = 0; i < size; ++i) {
+      if (version > 1) {
+        boolean isNull = source.readBoolean();
+        if (isNull) {
+          reuse.set(i, null);
+          continue;
+        }
+      }
+
       int fieldId = transformedFields[i].fieldId();
       Type.TypeID typeId = transformedFields[i].type().typeId();
       switch (typeId) {
@@ -276,11 +330,13 @@ public TypeSerializerSnapshot<SortKey> snapshotConfiguration() {
   }
 
   public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot<SortKey> {
-    private static final int CURRENT_VERSION = 1;
+    private static final int CURRENT_VERSION = 2;
 
     private Schema schema;
     private SortOrder sortOrder;
 
+    private int version = CURRENT_VERSION;
+
     /** Constructor for read instantiation. */
     @SuppressWarnings({"unused", "checkstyle:RedundantModifier"})
     public SortKeySerializerSnapshot() {
@@ -310,10 +366,16 @@ public void writeSnapshot(DataOutputView out) throws IOException {
     @Override
     public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader)
         throws IOException {
-      if (readVersion == 1) {
-        readV1(in);
-      } else {
-        throw new IllegalArgumentException("Unknown read version: " + readVersion);
+      switch (readVersion) {
+        case 1:
+          readV1(in);
+          this.version = 1;
+          break;
+        case 2:
+          readV1(in);
+          break;
+        default:
+          throw new IllegalArgumentException("Unknown read version: " + readVersion);
       }
     }
 
@@ -324,6 +386,10 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
         return TypeSerializerSchemaCompatibility.incompatible();
       }
 
+      if (oldSerializerSnapshot.getCurrentVersion() == 1 && this.getCurrentVersion() == 2) {
+        return TypeSerializerSchemaCompatibility.compatibleAfterMigration();
+      }
+
       // Sort order should be identical
       SortKeySerializerSnapshot oldSnapshot = (SortKeySerializerSnapshot) oldSerializerSnapshot;
       if (!sortOrder.sameOrder(oldSnapshot.sortOrder)) {
@@ -349,7 +415,7 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
     public TypeSerializer<SortKey> restoreSerializer() {
       Preconditions.checkState(schema != null, "Invalid schema: null");
       Preconditions.checkState(sortOrder != null, "Invalid sort order: null");
-      return new SortKeySerializer(schema, sortOrder);
+      return new SortKeySerializer(schema, sortOrder, version);
     }
 
     private void readV1(DataInputView in) throws IOException {
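Bumping CURRENT_VERSION to 2 while answering compatibleAfterMigration() for version-1 snapshots tells Flink to rewrite the state instead of reusing it as-is. A simplified sketch of the path Flink's state machinery then drives (the real calls live inside Flink; this only shows the shape):

    // 1. The old snapshot restores a serializer pinned to version 1 (no null markers),
    //    since readSnapshot() recorded version = 1 and restoreSerializer() passes it along.
    TypeSerializer<SortKey> oldSerializer = oldSnapshot.restoreSerializer();
    // 2. Existing state is decoded with the v1 read path ...
    SortKey key = oldSerializer.deserialize(oldStateInput);
    // 3. ... and re-encoded with the new serializer, null markers included.
    newSerializer.serialize(key, newStateOutput);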
---- changed file ----

@@ -73,12 +73,17 @@ static byte[] serializeCompletedStatistics(
   }
 
   static CompletedStatistics deserializeCompletedStatistics(
-      byte[] bytes, TypeSerializer<CompletedStatistics> statisticsSerializer) {
+      byte[] bytes, CompletedStatisticsSerializer statisticsSerializer) {
     try {
       DataInputDeserializer input = new DataInputDeserializer(bytes);
       return statisticsSerializer.deserialize(input);
-    } catch (IOException e) {
-      throw new UncheckedIOException("Fail to deserialize aggregated statistics", e);
+    } catch (Exception e) {
+      try {
+        DataInputDeserializer input = new DataInputDeserializer(bytes);
+        return statisticsSerializer.deserializeV1(input);
+      } catch (IOException ioException) {
+        throw new UncheckedIOException("Fail to deserialize aggregated statistics", ioException);
+      }
     }
   }
 
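The nested try/catch lets one entry point handle both formats: version-1 bytes carry no version int after the checkpoint id, so reading them with the v2 layout is expected to fail, and the second attempt re-reads the same buffer with deserializeV1. A hedged usage example (checkpointBytes assumed to come from either an old or a new checkpoint):

    CompletedStatistics restored =
        StatisticsUtil.deserializeCompletedStatistics(checkpointBytes, completedStatisticsSerializer);
    // succeeds for checkpoints written both before and after this change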
---- changed file ----

@@ -46,6 +46,7 @@
 import org.apache.iceberg.flink.TestFixtures;
 import org.apache.iceberg.flink.sink.shuffle.StatisticsType;
 import org.apache.iceberg.flink.source.BoundedTestSource;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
@@ -252,6 +253,44 @@ public void testRangeDistributionWithoutSortOrderPartitioned() throws Exception
     assertThat(snapshots).hasSizeGreaterThanOrEqualTo(numOfCheckpoints);
   }
 
+  @TestTemplate
+  public void testRangeDistributionWithNullValue() throws Exception {
+    assumeThat(partitioned).isTrue();
+
+    table
+        .updateProperties()
+        .set(TableProperties.WRITE_DISTRIBUTION_MODE, DistributionMode.RANGE.modeName())
+        .commit();
+
+    int numOfCheckpoints = 6;
+    List<List<Row>> charRows = createCharRows(numOfCheckpoints, 10);
+    charRows.add(ImmutableList.of(Row.of(1, null)));
+    DataStream<Row> dataStream =
+        env.addSource(createRangeDistributionBoundedSource(charRows), ROW_TYPE_INFO);
+    FlinkSink.Builder builder =
+        FlinkSink.forRow(dataStream, SimpleDataUtil.FLINK_SCHEMA)
+            .table(table)
+            .tableLoader(tableLoader)
+            .writeParallelism(parallelism);
+
+    // sort based on partition columns
+    builder.append();
+    env.execute(getClass().getSimpleName());
+
+    table.refresh();
+    // ordered in reverse timeline from the newest snapshot to the oldest snapshot
+    List<Snapshot> snapshots = Lists.newArrayList(table.snapshots().iterator());
+    // only keep the snapshots with added data files
+    snapshots =
+        snapshots.stream()
+            .filter(snapshot -> snapshot.addedDataFiles(table.io()).iterator().hasNext())
+            .collect(Collectors.toList());
+
+    // Sometimes we will have more checkpoints than the bounded source if we pass the
+    // auto checkpoint interval. Thus producing multiple snapshots.
+    assertThat(snapshots).hasSizeGreaterThanOrEqualTo(numOfCheckpoints);
+  }
+
   @TestTemplate
   public void testRangeDistributionWithSortOrder() throws Exception {
     table
---- remaining changed files not rendered in this view ----