Skip to content

Commit

Permalink
Flink: SortKeySerializer and CompletedStatisticsSerializer support ve…
Browse files Browse the repository at this point in the history
…rsion,and add UT for the change
  • Loading branch information
Guosmilesmile committed Dec 1, 2024
1 parent 627982a commit cd3d93d
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,35 @@ class CompletedStatisticsSerializer extends TypeSerializer<CompletedStatistics>
private final MapSerializer<SortKey, Long> keyFrequencySerializer;
private final ListSerializer<SortKey> keySamplesSerializer;

private int sortKeySerializerVersion = -1;

CompletedStatisticsSerializer(TypeSerializer<SortKey> sortKeySerializer) {
this.sortKeySerializer = sortKeySerializer;
this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class);
this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE);
this.keySamplesSerializer = new ListSerializer<>(sortKeySerializer);
}

public void changeSortKeySerializerVersion(int version) {
if (sortKeySerializer instanceof SortKeySerializer) {
((SortKeySerializer) sortKeySerializer).setVersion(version);
this.sortKeySerializerVersion = version;
}
}

public void changeSortKeySerializerVersionLatest() {
if (sortKeySerializer instanceof SortKeySerializer) {
((SortKeySerializer) sortKeySerializer).restoreToLatestVersion();
}
}

public int getSortKeySerializerVersionLatest() {
if (sortKeySerializer instanceof SortKeySerializer) {
return ((SortKeySerializer) sortKeySerializer).getLatestVersion();
}
return sortKeySerializerVersion;
}

@Override
public boolean isImmutableType() {
return false;
Expand Down Expand Up @@ -82,6 +104,17 @@ public int getLength() {

@Override
public void serialize(CompletedStatistics record, DataOutputView target) throws IOException {
target.writeLong(record.checkpointId());
target.writeInt(getSortKeySerializerVersionLatest());
statisticsTypeSerializer.serialize(record.type(), target);
if (record.type() == StatisticsType.Map) {
keyFrequencySerializer.serialize(record.keyFrequency(), target);
} else {
keySamplesSerializer.serialize(Arrays.asList(record.keySamples()), target);
}
}

public void serializeV1(CompletedStatistics record, DataOutputView target) throws IOException {
target.writeLong(record.checkpointId());
statisticsTypeSerializer.serialize(record.type(), target);
if (record.type() == StatisticsType.Map) {
Expand All @@ -93,15 +126,35 @@ public void serialize(CompletedStatistics record, DataOutputView target) throws

@Override
public CompletedStatistics deserialize(DataInputView source) throws IOException {
long checkpointId = source.readLong();
changeSortKeySerializerVersion(source.readInt());
StatisticsType type = statisticsTypeSerializer.deserialize(source);
if (type == StatisticsType.Map) {
Map<SortKey, Long> keyFrequency = keyFrequencySerializer.deserialize(source);
changeSortKeySerializerVersionLatest();
return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency);
} else {
List<SortKey> sortKeys = keySamplesSerializer.deserialize(source);
SortKey[] keySamples = new SortKey[sortKeys.size()];
keySamples = sortKeys.toArray(keySamples);
changeSortKeySerializerVersionLatest();
return CompletedStatistics.fromKeySamples(checkpointId, keySamples);
}
}

public CompletedStatistics deserializeV1(DataInputView source) throws IOException {
long checkpointId = source.readLong();
StatisticsType type = statisticsTypeSerializer.deserialize(source);
changeSortKeySerializerVersion(1);
if (type == StatisticsType.Map) {
Map<SortKey, Long> keyFrequency = keyFrequencySerializer.deserialize(source);
changeSortKeySerializerVersionLatest();
return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency);
} else {
List<SortKey> sortKeys = keySamplesSerializer.deserialize(source);
SortKey[] keySamples = new SortKey[sortKeys.size()];
keySamples = sortKeys.toArray(keySamples);
changeSortKeySerializerVersionLatest();
return CompletedStatistics.fromKeySamples(checkpointId, keySamples);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ public void resetToCheckpoint(long checkpointId, byte[] checkpointData) {
"Restoring data statistic coordinator {} from checkpoint {}", operatorName, checkpointId);
this.completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
checkpointData, completedStatisticsSerializer);
checkpointData, (CompletedStatisticsSerializer) completedStatisticsSerializer);

// recompute global statistics in case downstream parallelism changed
this.globalStatistics =
globalStatistics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,32 @@ class SortKeySerializer extends TypeSerializer<SortKey> {
private final int size;
private final Types.NestedField[] transformedFields;

private int version = SortKeySerializerSnapshot.CURRENT_VERSION;

private transient SortKey sortKey;

SortKeySerializer(Schema schema, SortOrder sortOrder, int version) {
this.version = version;
this.schema = schema;
this.sortOrder = sortOrder;
this.size = sortOrder.fields().size();

this.transformedFields = new Types.NestedField[size];
for (int i = 0; i < size; ++i) {
SortField sortField = sortOrder.fields().get(i);
Types.NestedField sourceField = schema.findField(sortField.sourceId());
Type resultType = sortField.transform().getResultType(sourceField.type());
Types.NestedField transformedField =
Types.NestedField.of(
sourceField.fieldId(),
sourceField.isOptional(),
sourceField.name(),
resultType,
sourceField.doc());
transformedFields[i] = transformedField;
}
}

SortKeySerializer(Schema schema, SortOrder sortOrder) {
this.schema = schema;
this.sortOrder = sortOrder;
Expand Down Expand Up @@ -83,6 +107,18 @@ private SortKey lazySortKey() {
return sortKey;
}

public int getLatestVersion() {
return snapshotConfiguration().getCurrentVersion();
}

public void restoreToLatestVersion() {
this.version = snapshotConfiguration().getCurrentVersion();
}

public void setVersion(int version) {
this.version = version;
}

@Override
public boolean isImmutableType() {
return false;
Expand Down Expand Up @@ -124,12 +160,14 @@ public void serialize(SortKey record, DataOutputView target) throws IOException
for (int i = 0; i < size; ++i) {
int fieldId = transformedFields[i].fieldId();
Type.TypeID typeId = transformedFields[i].type().typeId();
Object value = record.get(i, Object.class);
if (value == null) {
target.writeBoolean(true);
continue;
} else {
target.writeBoolean(false);
if (version > 1) {
Object value = record.get(i, Object.class);
if (value == null) {
target.writeBoolean(true);
continue;
} else {
target.writeBoolean(false);
}
}

switch (typeId) {
Expand Down Expand Up @@ -200,10 +238,12 @@ public SortKey deserialize(SortKey reuse, DataInputView source) throws IOExcepti
reuse.size(),
size);
for (int i = 0; i < size; ++i) {
boolean isNull = source.readBoolean();
if (isNull) {
reuse.set(i, null);
continue;
if (version > 1) {
boolean isNull = source.readBoolean();
if (isNull) {
reuse.set(i, null);
continue;
}
}

int fieldId = transformedFields[i].fieldId();
Expand Down Expand Up @@ -295,6 +335,8 @@ public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot<
private Schema schema;
private SortOrder sortOrder;

private int version = CURRENT_VERSION;

/** Constructor for read instantiation. */
@SuppressWarnings({"unused", "checkstyle:RedundantModifier"})
public SortKeySerializerSnapshot() {
Expand Down Expand Up @@ -326,11 +368,11 @@ public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCode
throws IOException {
switch (readVersion) {
case 1:
throw new UnsupportedOperationException(
String.format(
"No longer supported version [%s]. Please upgrade first . ", readVersion));
readV1(in);
this.version = 1;
break;
case 2:
readV2(in);
readV1(in);
break;
default:
throw new IllegalArgumentException("Unknown read version: " + readVersion);
Expand All @@ -344,8 +386,8 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
return TypeSerializerSchemaCompatibility.incompatible();
}

if (oldSerializerSnapshot.getCurrentVersion() != this.getCurrentVersion()) {
return TypeSerializerSchemaCompatibility.incompatible();
if (oldSerializerSnapshot.getCurrentVersion() == 1 && this.getCurrentVersion() == 2) {
return TypeSerializerSchemaCompatibility.compatibleAfterMigration();
}

// Sort order should be identical
Expand Down Expand Up @@ -373,10 +415,10 @@ public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
public TypeSerializer<SortKey> restoreSerializer() {
Preconditions.checkState(schema != null, "Invalid schema: null");
Preconditions.checkState(sortOrder != null, "Invalid sort order: null");
return new SortKeySerializer(schema, sortOrder);
return new SortKeySerializer(schema, sortOrder, version);
}

private void readV2(DataInputView in) throws IOException {
private void readV1(DataInputView in) throws IOException {
String schemaJson = StringUtils.readString(in);
String sortOrderJson = StringUtils.readString(in);
this.schema = SchemaParser.fromJson(schemaJson);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,17 @@ static byte[] serializeCompletedStatistics(
}

static CompletedStatistics deserializeCompletedStatistics(
byte[] bytes, TypeSerializer<CompletedStatistics> statisticsSerializer) {
byte[] bytes, CompletedStatisticsSerializer statisticsSerializer) {
try {
DataInputDeserializer input = new DataInputDeserializer(bytes);
return statisticsSerializer.deserialize(input);
} catch (IOException e) {
throw new UncheckedIOException("Fail to deserialize aggregated statistics", e);
} catch (Exception e) {
try {
DataInputDeserializer input = new DataInputDeserializer(bytes);
return statisticsSerializer.deserializeV1(input);
} catch (IOException ioException) {
throw new UncheckedIOException("Fail to deserialize aggregated statistics", ioException);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
package org.apache.iceberg.flink.sink.shuffle;

import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import org.apache.flink.api.common.typeutils.SerializerTestBase;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.iceberg.SortKey;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;

public class TestCompletedStatisticsSerializer extends SerializerTestBase<CompletedStatistics> {

Expand Down Expand Up @@ -51,4 +55,49 @@ protected CompletedStatistics[] getTestData() {
CompletedStatistics.fromKeySamples(2L, new SortKey[] {CHAR_KEYS.get("a"), CHAR_KEYS.get("b")})
};
}

@Test
public void testSerializer() throws Exception {
TypeSerializer<CompletedStatistics> completedStatisticsTypeSerializer = createSerializer();
CompletedStatistics[] data = getTestData();
DataOutputSerializer output = new DataOutputSerializer(1024);
completedStatisticsTypeSerializer.serialize(data[0], output);
byte[] serializedBytes = output.getCopyOfBuffer();

DataInputDeserializer input = new DataInputDeserializer(serializedBytes);
CompletedStatistics deserialized = completedStatisticsTypeSerializer.deserialize(input);
assertThat(deserialized).isEqualTo(data[0]);
}

@Test
public void testRestoreOldVersionSerializer() throws Exception {
CompletedStatisticsSerializer completedStatisticsTypeSerializer =
(CompletedStatisticsSerializer) createSerializer();
completedStatisticsTypeSerializer.changeSortKeySerializerVersion(1);
CompletedStatistics[] data = getTestData();
DataOutputSerializer output = new DataOutputSerializer(1024);
completedStatisticsTypeSerializer.serializeV1(data[0], output);
byte[] serializedBytes = output.getCopyOfBuffer();

completedStatisticsTypeSerializer.changeSortKeySerializerVersionLatest();
CompletedStatistics completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
serializedBytes, completedStatisticsTypeSerializer);
assertThat(completedStatistics).isEqualTo(data[0]);
}

@Test
public void testRestoreNewSerializer() throws Exception {
CompletedStatisticsSerializer completedStatisticsTypeSerializer =
(CompletedStatisticsSerializer) createSerializer();
CompletedStatistics[] data = getTestData();
DataOutputSerializer output = new DataOutputSerializer(1024);
completedStatisticsTypeSerializer.serialize(data[0], output);
byte[] serializedBytes = output.getCopyOfBuffer();

CompletedStatistics completedStatistics =
StatisticsUtil.deserializeCompletedStatistics(
serializedBytes, completedStatisticsTypeSerializer);
assertThat(completedStatistics).isEqualTo(data[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SCHEMA;
import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_KEY;
import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import java.io.IOException;
Expand Down Expand Up @@ -82,10 +81,18 @@ public void testRestoredOldSerializer() throws Exception {
SortKey sortKey = SORT_KEY.copy();
sortKey.wrap(struct);

SortKeySerializer originalSerializer = new SortKeySerializer(SCHEMA, SORT_ORDER);
assertThatThrownBy(() -> roundTripOldVersion(originalSerializer.snapshotConfiguration()))
.isInstanceOf(UnsupportedOperationException.class)
.hasMessageContaining("No longer supported version ");
SortKeySerializer originalSerializer = new SortKeySerializer(SCHEMA, SORT_ORDER, 1);
TypeSerializerSnapshot<SortKey> snapshot =
roundTrip(originalSerializer.snapshotConfiguration());
TypeSerializer<SortKey> restoredSerializer = snapshot.restoreSerializer();
((SortKeySerializer) restoredSerializer).setVersion(1);
DataOutputSerializer output = new DataOutputSerializer(1024);
originalSerializer.serialize(sortKey, output);
byte[] serializedBytes = output.getCopyOfBuffer();

DataInputDeserializer input = new DataInputDeserializer(serializedBytes);
SortKey deserialized = restoredSerializer.deserialize(input);
assertThat(deserialized).isEqualTo(sortKey);
}

@Test
Expand Down Expand Up @@ -225,19 +232,4 @@ private static SortKeySerializer.SortKeySerializerSnapshot roundTrip(
restored.readSnapshot(restored.getCurrentVersion(), in, original.getClass().getClassLoader());
return restored;
}

/** Copied from Flink {@code AvroSerializerSnapshotTest} */
private static SortKeySerializer.SortKeySerializerSnapshot roundTripOldVersion(
TypeSerializerSnapshot<SortKey> original) throws IOException {
// writeSnapshot();
DataOutputSerializer out = new DataOutputSerializer(1024);
original.writeSnapshot(out);
// init
SortKeySerializer.SortKeySerializerSnapshot restored =
new SortKeySerializer.SortKeySerializerSnapshot();
// readSnapshot();
DataInputView in = new DataInputDeserializer(out.wrapAsByteBuffer());
restored.readSnapshot(1, in, original.getClass().getClassLoader());
return restored;
}
}

0 comments on commit cd3d93d

Please sign in to comment.