fix partitionRange when reassign
zuston committed Apr 8, 2024
1 parent b0b68fa commit 761dedf
Showing 1 changed file with 65 additions and 10 deletions.
@@ -18,7 +18,10 @@
package org.apache.spark.shuffle;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -27,6 +30,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;

import scala.Tuple2;
@@ -1211,26 +1215,56 @@ public ShuffleServerInfo reassignFaultyShuffleServerForTasks(
}

// get the newer server to replace faulty server.
ShuffleServerInfo newAssignedServer = assignShuffleServer(shuffleId, faultyShuffleServerId);
ShuffleServerInfo newAssignedServer =
reassignShuffleServerForTask(shuffleId, partitionIds, faultyShuffleServerId);
if (newAssignedServer != null) {
handleInfo.createNewReassignmentForMultiPartitions(
partitionIds, faultyShuffleServerId, newAssignedServer);
}
LOG.info("Reassign shuffle-server from {} -> {} for shuffleId: {}, partitionIds: {}",
faultyShuffleServerId, newAssignedServer, shuffleId, partitionIds);
LOG.info(
"Reassign shuffle-server from {} -> {} for shuffleId: {}, partitionIds: {}",
faultyShuffleServerId,
newAssignedServer,
shuffleId,
partitionIds);
return newAssignedServer;
}
}

private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) {
private ShuffleServerInfo reassignShuffleServerForTask(
int shuffleId, Set<Integer> partitionIds, String faultyShuffleServerId) {
Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
faultyServerIds.addAll(failuresShuffleServerIds);
AtomicReference<ShuffleServerInfo> replacementRef = new AtomicReference<>();
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds);
if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) {
return partitionToServers.get(0).get(0);
}
return null;
requestShuffleAssignment(
shuffleId,
1,
1,
1,
1,
faultyServerIds,
shuffleAssignmentsInfo -> {
if (shuffleAssignmentsInfo == null) {
return null;
}
Optional<List<ShuffleServerInfo>> replacementOpt =
shuffleAssignmentsInfo.getPartitionToServers().values().stream().findFirst();
ShuffleServerInfo replacement = replacementOpt.get().get(0);
replacementRef.set(replacement);

Map<Integer, List<ShuffleServerInfo>> newPartitionToServers = new HashMap<>();
List<PartitionRange> partitionRanges = new ArrayList<>();
for (Integer partitionId : partitionIds) {
newPartitionToServers.put(partitionId, Arrays.asList(replacement));
partitionRanges.add(new PartitionRange(partitionId, partitionId));
}
Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
new HashMap<>();
serverToPartitionRanges.put(replacement, partitionRanges);
return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges);
});
return replacementRef.get();
}

private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
@@ -1239,7 +1273,8 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
int partitionNumPerRange,
int assignmentShuffleServerNumber,
int estimateTaskConcurrency,
Set<String> faultyServerIds) {
Set<String> faultyServerIds,
Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> reassignmentHandler) {
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);
@@ -1259,6 +1294,9 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
assignmentShuffleServerNumber,
estimateTaskConcurrency,
faultyServerIds);
if (reassignmentHandler != null) {
response = reassignmentHandler.apply(response);
}
registerShuffleServers(
id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
return response.getPartitionToServers();
@@ -1270,6 +1308,23 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
}
}

private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
int shuffleId,
int partitionNum,
int partitionNumPerRange,
int assignmentShuffleServerNumber,
int estimateTaskConcurrency,
Set<String> faultyServerIds) {
return requestShuffleAssignment(
shuffleId,
partitionNum,
partitionNumPerRange,
assignmentShuffleServerNumber,
estimateTaskConcurrency,
faultyServerIds,
null);
}

private RemoteStorageInfo getRemoteStorageInfo() {
String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
RemoteStorageInfo defaultRemoteStorage =
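
For readers following the diff, here is a minimal, self-contained sketch of the pattern the commit introduces: the assignment request gains an optional post-processing hook, and the reassignment path uses that hook to map every affected partition to the replacement server with its own single-partition PartitionRange, instead of only consulting partition 0 as the removed code did. ServerInfo, AssignmentInfo, requestAssignment, and reassign below are simplified stand-ins for illustration, not the actual Uniffle classes.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

public class ReassignmentSketch {

  // Simplified stand-ins for Uniffle's ShuffleServerInfo, PartitionRange,
  // and ShuffleAssignmentsInfo.
  record ServerInfo(String id) {}

  record PartitionRange(int start, int end) {}

  record AssignmentInfo(
      Map<Integer, List<ServerInfo>> partitionToServers,
      Map<ServerInfo, List<PartitionRange>> serverToPartitionRanges) {}

  // Stands in for requestShuffleAssignment: fetch an assignment, let an
  // optional handler rewrite it, then hand it back for registration.
  static AssignmentInfo requestAssignment(
      AssignmentInfo fetched, Function<AssignmentInfo, AssignmentInfo> handler) {
    return handler != null ? handler.apply(fetched) : fetched;
  }

  // Mirrors the reassignment path: take the first assigned server as the
  // replacement and rebuild both maps so every affected partition gets a
  // single-partition range on that server.
  static ServerInfo reassign(AssignmentInfo fetched, Set<Integer> partitionIds) {
    AtomicReference<ServerInfo> replacementRef = new AtomicReference<>();
    requestAssignment(
        fetched,
        info -> {
          if (info == null) {
            return null;
          }
          ServerInfo replacement =
              info.partitionToServers().values().iterator().next().get(0);
          replacementRef.set(replacement);

          Map<Integer, List<ServerInfo>> newPartitionToServers = new HashMap<>();
          List<PartitionRange> ranges = new ArrayList<>();
          for (Integer partitionId : partitionIds) {
            // One [partitionId, partitionId] range per affected partition.
            newPartitionToServers.put(partitionId, List.of(replacement));
            ranges.add(new PartitionRange(partitionId, partitionId));
          }
          Map<ServerInfo, List<PartitionRange>> serverToRanges = new HashMap<>();
          serverToRanges.put(replacement, ranges);
          return new AssignmentInfo(newPartitionToServers, serverToRanges);
        });
    return replacementRef.get();
  }

  public static void main(String[] args) {
    AssignmentInfo fetched =
        new AssignmentInfo(
            Map.of(0, List.of(new ServerInfo("server-2"))),
            Map.of(new ServerInfo("server-2"), List.of(new PartitionRange(0, 0))));
    // Partitions 3 and 7 lived on the faulty server; both move to server-2.
    System.out.println(reassign(fetched, Set.of(3, 7)));
  }
}

The AtomicReference mirrors how the commit surfaces the chosen replacement out of the handler, since the hook itself only returns the rewritten assignment that gets registered.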
