[FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (#22584) (#22594)

This commit fixes a bug in StateAssignmentOperation during recovery from unaligned checkpoints: stateless operators whose upstream operators have result subpartition (output) state, or whose downstream operators have input channel state, were previously skipped during state repartitioning.

(cherry picked from commit 354c0f4)
StefanRRichter authored May 17, 2023
1 parent 5fdcfbb commit 8d8a486
Showing 6 changed files with 335 additions and 77 deletions.
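
At its core, the commit widens the guards in StateAssignmentOperation so that channel-state repartitioning also runs for stateless tasks whose neighbors carry in-flight channel state. A minimal, self-contained Java sketch of that guard, using simplified hypothetical stand-in types (the real logic lives in StateAssignmentOperation and TaskStateAssignment in the hunks below):

```java
import java.util.List;

// Hypothetical, simplified stand-in for TaskStateAssignment.
class AssignmentSketch {
    boolean hasNonFinishedState; // the task itself carries state
    boolean hasOutputState; // the task carries result subpartition (output) state
    boolean hasInputState; // the task carries input channel state
    List<AssignmentSketch> upstream;
    List<AssignmentSketch> downstream;

    // Mirrors the new helper: does any upstream task carry output state?
    boolean hasUpstreamOutputStates() {
        return upstream.stream().anyMatch(a -> a.hasOutputState);
    }

    // Mirrors the new helper: does any downstream task carry input state?
    boolean hasDownstreamInputStates() {
        return downstream.stream().anyMatch(a -> a.hasInputState);
    }

    // The fixed guard: before this commit only hasNonFinishedState was checked,
    // so a stateless task sandwiched between channel state was never repartitioned.
    boolean needsRepartitioning() {
        return hasNonFinishedState || hasUpstreamOutputStates() || hasDownstreamInputStates();
    }
}
```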
@@ -136,19 +136,24 @@ public void assignStates() {
 
         // repartition state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            if (stateAssignment.hasNonFinishedState) {
+            if (stateAssignment.hasNonFinishedState
+                    // FLINK-31963: We need to run repartitioning for stateless operators that have
+                    // upstream output or downstream input states.
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignAttemptState(stateAssignment);
             }
         }
 
         // actually assign the state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            // If upstream has output states, even the empty task state should be assigned for the
-            // current task in order to notify this task that the old states will send to it which
-            // likely should be filtered.
+            // If upstream has output states or downstream has input states, even the empty task
+            // state should be assigned for the current task in order to notify this task that the
+            // old states will send to it which likely should be filtered.
             if (stateAssignment.hasNonFinishedState
                     || stateAssignment.isFullyFinished
-                    || stateAssignment.hasUpstreamOutputStates()) {
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignTaskStateToExecutionJobVertices(stateAssignment);
             }
         }
@@ -345,9 +350,10 @@ public static <T extends StateObject> void reDistributePartitionableStates(
                                                 newParallelism)));
     }
 
-    public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
-            TaskStateAssignment assignment) {
-        if (!assignment.hasOutputState) {
+    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
+        // FLINK-31963: We can skip this phase if there is no output state AND downstream has no
+        // input states
+        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
             return;
         }
 
@@ -394,7 +400,9 @@ public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpa
     }
 
     public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
-        if (!stateAssignment.hasInputState) {
+        // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
+        // output states
+        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
             return;
         }
 
@@ -435,7 +443,7 @@ public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment)
                         : getPartitionState(
                                 inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
         final MappingBasedRepartitioner<InputChannelStateHandle> repartitioner =
-                new MappingBasedRepartitioner(mapping);
+                new MappingBasedRepartitioner<>(mapping);
         final Map<OperatorInstanceID, List<InputChannelStateHandle>> repartitioned =
                 applyRepartitioner(
                         stateAssignment.inputOperatorID,
@@ -84,6 +84,8 @@ class TaskStateAssignment {
 
     @Nullable private TaskStateAssignment[] downstreamAssignments;
     @Nullable private TaskStateAssignment[] upstreamAssignments;
+    @Nullable private Boolean hasUpstreamOutputStates;
+    @Nullable private Boolean hasDownstreamInputStates;
 
     private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
     private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;
@@ -202,8 +204,21 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
     }
 
     public boolean hasUpstreamOutputStates() {
-        return Arrays.stream(getUpstreamAssignments())
-                .anyMatch(assignment -> assignment.hasOutputState);
+        if (hasUpstreamOutputStates == null) {
+            hasUpstreamOutputStates =
+                    Arrays.stream(getUpstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasOutputState);
+        }
+        return hasUpstreamOutputStates;
     }
+
+    public boolean hasDownstreamInputStates() {
+        if (hasDownstreamInputStates == null) {
+            hasDownstreamInputStates =
+                    Arrays.stream(getDownstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasInputState);
+        }
+        return hasDownstreamInputStates;
+    }
 
     private InflightDataGateOrPartitionRescalingDescriptor log(
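The two new @Nullable Boolean fields lazily cache the upstream/downstream scans, since the predicates are now consulted several times per assignment (in assignStates and both reDistribute* methods). A minimal sketch of the same memoization idiom, with a hypothetical expensiveScan standing in for the Arrays.stream(...).anyMatch(...) over neighbor assignments:

```java
import java.util.function.BooleanSupplier;

// Minimal sketch of the lazy Boolean caching used above.
class MemoizedFlag {
    private Boolean cached; // null means "not computed yet"

    // Computes the flag at most once; later calls return the cached value.
    boolean get(BooleanSupplier expensiveScan) {
        if (cached == null) {
            cached = expensiveScan.getAsBoolean();
        }
        return cached;
    }
}
```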
@@ -22,6 +22,7 @@
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
@@ -51,6 +52,9 @@
 import org.junit.ClassRule;
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.EnumMap;
@@ -82,6 +86,7 @@
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -785,6 +790,129 @@ public void testOnlyUpstreamChannelStateAssignment()
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewResultSubpartitionStateHandle(10, random),
+                                                createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewInputChannelStateHandle(10, random),
+                                                createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        checkArgument(
+                upstreamOpState != downstreamOpState
+                        && (upstreamOpState == null || downstreamOpState == null),
+                "Either upstream or downstream state must exist, but not both");
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamTaskStateSnapshots,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamTaskStateSnapshots,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);
+    }
+
+    private void checkMappings(
+            List<TaskStateSnapshot> taskStateSnapshots,
+            Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun,
+            int expectedCount) {
+        Assert.assertEquals(
+                expectedCount,
+                taskStateSnapshots.stream()
+                        .map(extractFun)
+                        .mapToInt(
+                                x -> {
+                                    int len = x.getOldSubtaskIndexes(0).length;
+                                    // Assert that there is a mapping.
+                                    Assert.assertTrue(len > 0);
+                                    return len;
+                                })
+                        .sum());
+    }
+
     @Test
     public void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
@@ -949,15 +1077,50 @@ private Map<OperatorID, OperatorState> buildOperatorStates(
                                 }));
     }
 
+    private static class OperatorIdWithParallelism {
+        private final OperatorID operatorID;
+        private final int parallelism;
+
+        public OperatorID getOperatorID() {
+            return operatorID;
+        }
+
+        public int getParallelism() {
+            return parallelism;
+        }
+
+        public OperatorIdWithParallelism(OperatorID operatorID, int parallelism) {
+            this.operatorID = operatorID;
+            this.parallelism = parallelism;
+        }
+    }
+
     private Map<OperatorID, ExecutionJobVertex> buildVertices(
             List<OperatorID> operatorIds,
-            int parallelism,
+            int parallelisms,
             SubtaskStateMapper downstreamRescaler,
             SubtaskStateMapper upstreamRescaler)
             throws JobException, JobExecutionException {
-        final JobVertex[] jobVertices =
+        List<OperatorIdWithParallelism> opIdsWithParallelism =
                 operatorIds.stream()
-                        .map(id -> createJobVertex(id, id, parallelism))
+                        .map(operatorID -> new OperatorIdWithParallelism(operatorID, parallelisms))
+                        .collect(Collectors.toList());
+        return buildVertices(opIdsWithParallelism, downstreamRescaler, upstreamRescaler);
+    }
+
+    private Map<OperatorID, ExecutionJobVertex> buildVertices(
+            List<OperatorIdWithParallelism> operatorIdsAndParallelism,
+            SubtaskStateMapper downstreamRescaler,
+            SubtaskStateMapper upstreamRescaler)
+            throws JobException, JobExecutionException {
+        final JobVertex[] jobVertices =
+                operatorIdsAndParallelism.stream()
+                        .map(
+                                idWithParallelism ->
+                                        createJobVertex(
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getParallelism()))
                         .toArray(JobVertex[]::new);
         for (int index = 1; index < jobVertices.length; index++) {
             connectVertices(
@@ -1029,6 +1192,15 @@ private JobVertex createJobVertex(
         return jobVertex;
     }
 
+    private List<TaskStateSnapshot> getTaskStateSnapshotFromVertex(
+            ExecutionJobVertex executionJobVertex) {
+        return Arrays.stream(executionJobVertex.getTaskVertices())
+                .map(ExecutionVertex::getCurrentExecutionAttempt)
+                .map(Execution::getTaskRestore)
+                .map(JobManagerTaskRestore::getTaskStateSnapshot)
+                .collect(Collectors.toList());
+    }
+
     private OperatorSubtaskState getAssignedState(
             ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
         return executionJobVertex
@@ -113,15 +113,17 @@ public void create(
                StreamExecutionEnvironment env,
                int minCheckpoints,
                boolean slotSharing,
-               int expectedRestarts) {
+               int expectedRestarts,
+               long sourceSleepMs) {
            final int parallelism = env.getParallelism();
            final SingleOutputStreamOperator<Long> stream =
                    env.fromSource(
                                    new LongSource(
                                            minCheckpoints,
                                            parallelism,
                                            expectedRestarts,
-                                           env.getCheckpointInterval()),
+                                           env.getCheckpointInterval(),
+                                           sourceSleepMs),
                                    noWatermarks(),
                                    "source")
                            .slotSharingGroup(slotSharing ? "default" : "source")
@@ -144,7 +146,8 @@ public void create(
                StreamExecutionEnvironment env,
                int minCheckpoints,
                boolean slotSharing,
-               int expectedRestarts) {
+               int expectedRestarts,
+               long sourceSleepMs) {
            final int parallelism = env.getParallelism();
            DataStream<Long> combinedSource = null;
            for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -154,7 +157,8 @@
                                        minCheckpoints,
                                        parallelism,
                                        expectedRestarts,
-                                       env.getCheckpointInterval()),
+                                       env.getCheckpointInterval(),
+                                       sourceSleepMs),
                                noWatermarks(),
                                "source" + inputIndex)
                        .slotSharingGroup(
@@ -182,7 +186,8 @@ public void create(
                StreamExecutionEnvironment env,
                int minCheckpoints,
                boolean slotSharing,
-               int expectedRestarts) {
+               int expectedRestarts,
+               long sourceSleepMs) {
            final int parallelism = env.getParallelism();
            DataStream<Tuple2<Integer, Long>> combinedSource = null;
            for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -193,7 +198,8 @@
                                            minCheckpoints,
                                            parallelism,
                                            expectedRestarts,
-                                           env.getCheckpointInterval()),
+                                           env.getCheckpointInterval(),
+                                           sourceSleepMs),
                                    noWatermarks(),
                                    "source" + inputIndex)
                            .slotSharingGroup(
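
The last file threads a new sourceSleepMs parameter through the test topologies into LongSource; the LongSource change itself is in one of the two changed files not expanded on this page. Presumably the source sleeps between emitted records so that checkpoints actually capture in-flight channel data during rescaling. A hypothetical sketch of such a throttled emit loop (ThrottledEmitter, emit, and the loop bound are stand-ins, not LongSource's real API):

```java
import java.util.function.LongConsumer;

// Hypothetical throttled emitter; LongSource's actual reader is not shown in this diff.
class ThrottledEmitter {
    static void emit(LongConsumer output, long count, long sourceSleepMs)
            throws InterruptedException {
        for (long value = 0; value < count; value++) {
            output.accept(value);
            if (sourceSleepMs > 0) {
                Thread.sleep(sourceSleepMs); // slow emission so checkpoints see in-flight data
            }
        }
    }
}
```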