Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,19 @@ private <S> void deserializeOperatorStateValues(
OperatorStateHandle.StateMetaInfo metaInfo)
throws IOException {

if (null != metaInfo) {
if (metaInfo != null) {
long[] offsets = metaInfo.getOffsets();
if (null != offsets) {
if (offsets != null) {
DataInputView div = new DataInputViewStreamWrapper(in);
TypeSerializer<S> serializer =
stateListForName.getStateMetaInfo().getPartitionStateSerializer();
long currentPos = in.getPos();
for (long offset : offsets) {
in.seek(offset);
if (currentPos != offset) {
in.seek(offset);
}
stateListForName.add(serializer.deserialize(div));
currentPos = in.getPos();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
Expand All @@ -31,6 +32,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.testcontainers.utility.ThrowingFunction;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -39,6 +41,8 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -268,6 +272,39 @@ void testRepartitionOperatorState(boolean snapshotCompressionEnabled) throws Exc
}
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testRestoreAvoidsRedundantSeeksForSequentialOffsets(boolean snapshotCompressionEnabled)
throws Exception {
final ExecutionConfig cfg = new ExecutionConfig();
cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
final ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend>
operatorStateBackendFactory =
createOperatorStateBackendFactory(
cfg, new CloseableRegistry(), this.getClass().getClassLoader());

final Map<String, List<String>> listStates = new HashMap<>();
final List<String> values =
IntStream.range(0, 100).mapToObj(idx -> "v" + idx).collect(Collectors.toList());
listStates.put("s1", values);

final OperatorStateHandle stateHandle =
createOperatorStateHandle(
operatorStateBackendFactory, listStates, Collections.emptyMap());

final AtomicInteger seekCount = new AtomicInteger();
final OperatorStateHandle stateHandleWithCountingSeeks =
wrapWithCountingSeekInputStream(stateHandle, seekCount);

verifyOperatorStateHandle(
operatorStateBackendFactory,
Collections.singletonList(stateHandleWithCountingSeeks),
listStates,
Collections.emptyMap());

assertThat(seekCount.get()).isLessThanOrEqualTo(1);
}

/**
* This is a simplified version of what RR partitioner does, so it only works in case there is
* no remainder.
Expand All @@ -286,4 +323,107 @@ private static Map<String, List<String>> getExpectedSplit(
}
return newStates;
}

private static OperatorStateHandle wrapWithCountingSeekInputStream(
OperatorStateHandle stateHandle, AtomicInteger seekCount) {
final StreamStateHandle delegate = stateHandle.getDelegateStateHandle();
final byte[] data =
delegate.asBytesIfInMemory()
.orElseThrow(
() ->
new IllegalStateException(
"Expected in-memory state handle for test."));
return new OperatorStreamStateHandle(
stateHandle.getStateNameToPartitionOffsets(),
new CountingStreamStateHandle(data, seekCount, delegate.getStreamStateHandleID()));
}

private static final class CountingStreamStateHandle implements StreamStateHandle {

private static final long serialVersionUID = 1L;

private final byte[] data;
private final AtomicInteger seekCount;
private final PhysicalStateHandleID streamStateHandleId;

private CountingStreamStateHandle(
byte[] data, AtomicInteger seekCount, PhysicalStateHandleID streamStateHandleId) {
this.data = data;
this.seekCount = seekCount;
this.streamStateHandleId = streamStateHandleId;
}

@Override
public FSDataInputStream openInputStream() {
return new CountingByteArrayFSDataInputStream(data, seekCount);
}

@Override
public Optional<byte[]> asBytesIfInMemory() {
return Optional.of(data);
}

@Override
public PhysicalStateHandleID getStreamStateHandleID() {
return streamStateHandleId;
}

@Override
public void discardState() {}

@Override
public long getStateSize() {
return data.length;
}

@Override
public void collectSizeStats(StateObjectSizeStatsCollector collector) {
collector.add(StateObjectLocation.LOCAL_MEMORY, getStateSize());
}
}

private static final class CountingByteArrayFSDataInputStream extends FSDataInputStream {

private final byte[] data;
private final AtomicInteger seekCount;
private int index;

private CountingByteArrayFSDataInputStream(byte[] data, AtomicInteger seekCount) {
this.data = data;
this.seekCount = seekCount;
}

@Override
public void seek(long desired) throws IOException {
seekCount.incrementAndGet();
if (desired >= 0 && desired <= data.length) {
index = (int) desired;
} else {
throw new IOException("position out of bounds");
}
}

@Override
public long getPos() {
return index;
}

@Override
public int read() {
return index < data.length ? data[index++] & 0xFF : -1;
}

@Override
public int read(byte[] b, int off, int len) {
final int bytesLeft = data.length - index;
if (bytesLeft > 0) {
final int bytesToCopy = Math.min(len, bytesLeft);
System.arraycopy(data, index, b, off, bytesToCopy);
index += bytesToCopy;
return bytesToCopy;
} else {
return len == 0 ? 0 : -1;
}
}
}
}