
[#825][FOLLOWUP] fix(spark): Apply a thread safety way to track the blocks sending result (#1260)

### What changes were proposed in this pull request?
As title

### Why are the changes needed?
```
[INFO] Running org.apache.uniffle.test.ContinuousSelectPartitionStrategyTest
Error:  Tests run: 1, Failures: 0, Errors: 1, Skipped: 0, Time elapsed: 59.195 s <<< FAILURE! - in org.apache.uniffle.test.ContinuousSelectPartitionStrategyTest
Error:  resultCompareTest  Time elapsed: 55.751 s  <<< ERROR!
org.apache.spark.SparkException: 
Job aborted due to stage failure: Task 6 in stage 1.0 failed 1 times, most recent failure: Lost task 6.0 in stage 1.0 (TID 16) (fv-az391-410.nf14wd45lyte3l5gjbhk121dmd.jx.internal.cloudapp.net executor driver): org.apache.uniffle.common.exception.RssException: Timeout: Task[16_0] failed because 9 blocks can't be sent to shuffle server in 30000 ms.
	at org.apache.spark.shuffle.writer.RssShuffleWriter.checkBlockSendResult(RssShuffleWriter.java:350)
	at org.apache.spark.shuffle.writer.RssShuffleWriter.writeImpl(RssShuffleWriter.java:246)
	at org.apache.spark.shuffle.writer.RssShuffleWriter.write(RssShuffleWriter.java:209)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
```
[ActionLink](https://github.com/apache/incubator-uniffle/actions/runs/6611324517/job/17954967498?pr=1257)
I debugged the test locally and found that all blocks were sent successfully, but some of them were missing from the block tracker.
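
The root cause is that the per-block tracking collections were updated concurrently from the data pusher's send callbacks without being fully thread safe. Below is a minimal, illustrative Java sketch of the approach this patch takes (class and method names here are simplified assumptions, not the actual Uniffle types): pre-populate the success tracker with `AtomicInteger` counters on a single thread before any asynchronous send starts, and collect the servers of failed sends in `LinkedBlockingQueue`s created through `computeIfAbsent` on a concurrent map.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative sketch only; names are simplified and are not the real Uniffle classes.
class BlockSendTrackerSketch {

  // Success counters: the key set is fixed up front on a single thread, so a plain
  // HashMap is safe here; only the AtomicInteger values are mutated by callback threads.
  private final Map<Long, AtomicInteger> successTracker = new HashMap<>();

  // Failure tracker: entries are created lazily from several callback threads, so both
  // the map and the per-block collection must be thread safe.
  private final Map<Long, BlockingQueue<String>> failTracker = new ConcurrentHashMap<>();

  BlockSendTrackerSketch(List<Long> allBlockIds) {
    // Pre-populate before any asynchronous send is started.
    allBlockIds.forEach(id -> successTracker.put(id, new AtomicInteger(0)));
  }

  // Invoked from an asynchronous send-success callback.
  void onSendSuccess(long blockId) {
    successTracker.get(blockId).incrementAndGet();
  }

  // Invoked from an asynchronous send-failure callback.
  void onSendFailure(long blockId, String serverId) {
    failTracker.computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>()).add(serverId);
  }

  // A block counts as sent once it reached the configured number of replica writes.
  boolean reachedReplicaWrite(long blockId, int replicaWrite) {
    return successTracker.get(blockId).get() >= replicaWrite;
  }
}
```

For comparison, the previous code populated both trackers from the callbacks themselves via `computeIfAbsent(blockId, id -> Lists.newArrayList())`; an `ArrayList` is not thread safe, so concurrent callbacks for the same block could lose an entry, which is consistent with the symptom above of blocks that were sent but never counted as successful.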

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Integration tests.
In particular, the test below was run in a loop many times without a single failure:
```
mvn -B -fae test -Dtest=org.apache.uniffle.test.ContinuousSelectPartitionStrategyTest -Pspark3
```
summaryzb authored Oct 25, 2023
1 parent 3071099 commit ec7f85c
Showing 9 changed files with 71 additions and 47 deletions.

@@ -24,6 +24,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
@@ -56,7 +57,8 @@ public class DataPusher implements Closeable {
private final Map<String, Set<Long>> taskToSuccessBlockIds;
// Must be thread safe
private final Map<String, Set<Long>> taskToFailedBlockIds;
private final Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer;
private final Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>>
taskToFailedBlockIdsAndServer;
private String rssAppId;
// Must be thread safe
private final Set<String> failedTaskIds;
@@ -65,7 +67,7 @@ public DataPusher(
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime) {
@@ -126,9 +128,9 @@ private synchronized void putBlockId(
}

private synchronized void putSendFailedBlockIdAndShuffleServer(
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
String taskAttemptId,
Map<Long, List<ShuffleServerInfo>> blockIdsAndServer) {
Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdsAndServer) {
if (blockIdsAndServer == null || blockIdsAndServer.isEmpty()) {
return;
}

@@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
@@ -81,7 +82,7 @@ public void testSendData() throws ExecutionException, InterruptedException {

Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
Map<String, Set<Long>> taskToFailedBlockIds = Maps.newConcurrentMap();
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer =
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer =
JavaUtils.newConcurrentMap();
Set<String> failedTaskIds = new HashSet<>();


@@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@@ -87,8 +88,8 @@ public class RssShuffleManager extends RssShuffleManagerBase {
private Map<String, Set<Long>> taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
private Map<String, Set<Long>> taskToFailedBlockIds = JavaUtils.newConcurrentMap();
// Record both the block that failed to be sent and the ShuffleServer
private final Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer =
JavaUtils.newConcurrentMap();
private final Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>>
taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap();
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
@@ -703,10 +704,11 @@ private Roaring64NavigableMap getShuffleResult(
* @param taskId Shuffle taskId
* @return List of failed ShuffleServer blocks
*/
public Map<Long, List<ShuffleServerInfo>> getFailedBlockIdsWithShuffleServer(String taskId) {
Map<Long, List<ShuffleServerInfo>> result = taskToFailedBlockIdsAndServer.get(taskId);
public Map<Long, BlockingQueue<ShuffleServerInfo>> getFailedBlockIdsWithShuffleServer(
String taskId) {
Map<Long, BlockingQueue<ShuffleServerInfo>> result = taskToFailedBlockIdsAndServer.get(taskId);
if (result == null) {
result = JavaUtils.newConcurrentMap();
result = Collections.emptyMap();
}
return result;
}

@@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -162,7 +163,7 @@ private FakedDataPusher(
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime,

@@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@@ -98,7 +99,8 @@ public class RssShuffleManager extends RssShuffleManagerBase {
private final Map<String, Set<Long>> taskToSuccessBlockIds;
private final Map<String, Set<Long>> taskToFailedBlockIds;
// Record both the block that failed to be sent and the ShuffleServer
private final Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer;
private final Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>>
taskToFailedBlockIdsAndServer;
private ScheduledExecutorService heartBeatScheduledExecutorService;
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
@@ -270,7 +272,7 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s
DataPusher dataPusher,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer) {
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer) {
this.sparkConf = conf;
this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
this.dataDistributionType =
@@ -999,12 +1001,13 @@ private Roaring64NavigableMap getShuffleResultForMultiPart(
* The ShuffleServer list of block sending failures is returned using the shuffle task ID
*
* @param taskId Shuffle taskId
* @return List of failed ShuffleServer blocks
* @return failed ShuffleServer blocks
*/
public Map<Long, List<ShuffleServerInfo>> getFailedBlockIdsWithShuffleServer(String taskId) {
Map<Long, List<ShuffleServerInfo>> result = taskToFailedBlockIdsAndServer.get(taskId);
public Map<Long, BlockingQueue<ShuffleServerInfo>> getFailedBlockIdsWithShuffleServer(
String taskId) {
Map<Long, BlockingQueue<ShuffleServerInfo>> result = taskToFailedBlockIdsAndServer.get(taskId);
if (result == null) {
result = JavaUtils.newConcurrentMap();
result = Collections.emptyMap();
}
return result;
}

@@ -17,9 +17,9 @@

package org.apache.spark.shuffle;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;

import org.apache.commons.lang3.SystemUtils;
import org.apache.spark.SparkConf;
@@ -37,7 +37,7 @@ public static RssShuffleManager createShuffleManager(
DataPusher dataPusher,
Map<String, Set<Long>> successBlockIds,
Map<String, Set<Long>> failBlockIds,
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer) {
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer) {
return new RssShuffleManager(
conf, isDriver, dataPusher, successBlockIds, failBlockIds, taskToFailedBlockIdsAndServer);
}

@@ -24,6 +24,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -82,12 +83,10 @@ public void checkBlockSendResultTest() {
.set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346");
Map<String, Set<Long>> failBlocks = JavaUtils.newConcurrentMap();
Map<String, Set<Long>> successBlocks = JavaUtils.newConcurrentMap();
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer =
JavaUtils.newConcurrentMap();
Serializer kryoSerializer = new KryoSerializer(conf);
RssShuffleManager manager =
TestUtils.createShuffleManager(
conf, false, null, successBlocks, failBlocks, taskToFailedBlockIdsAndServer);
conf, false, null, successBlocks, failBlocks, JavaUtils.newConcurrentMap());

ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
Partitioner mockPartitioner = mock(Partitioner.class);
@@ -164,7 +163,7 @@ private FakedDataPusher(
ShuffleWriteClient shuffleWriteClient,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds,
Map<String, Map<Long, List<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Map<String, Map<Long, BlockingQueue<ShuffleServerInfo>>> taskToFailedBlockIdsAndServer,
Set<String> failedTaskIds,
int threadPoolSize,
int threadKeepAliveTime,

@@ -23,12 +23,14 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
@@ -151,8 +153,8 @@ private boolean sendShuffleDataAsync(
String appId,
Map<ShuffleServerInfo, Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>> serverToBlocks,
Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
Map<Long, List<ShuffleServerInfo>> blockIdsSendSuccessTracker,
Map<Long, List<ShuffleServerInfo>> blockIdsSendFailTracker,
Map<Long, AtomicInteger> blockIdsSendSuccessTracker,
Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdsSendFailTracker,
boolean allowFastFail,
Supplier<Boolean> needCancelRequest) {

@@ -195,10 +197,7 @@ private boolean sendShuffleDataAsync(
serverToBlockIds
.get(ssi)
.forEach(
blockId ->
blockIdsSendSuccessTracker
.computeIfAbsent(blockId, id -> Lists.newArrayList())
.add(ssi));
blockId -> blockIdsSendSuccessTracker.get(blockId).incrementAndGet());
if (defectiveServers != null) {
defectiveServers.remove(ssi);
}
@@ -211,7 +210,7 @@ private boolean sendShuffleDataAsync(
.forEach(
blockId ->
blockIdsSendFailTracker
.computeIfAbsent(blockId, id -> Lists.newArrayList())
.computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>())
.add(ssi));
if (defectiveServers != null) {
defectiveServers.add(ssi);
@@ -225,7 +224,7 @@ private boolean sendShuffleDataAsync(
.forEach(
blockId ->
blockIdsSendFailTracker
.computeIfAbsent(blockId, id -> Lists.newArrayList())
.computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>())
.add(ssi));
if (defectiveServers != null) {
defectiveServers.add(ssi);
@@ -355,16 +354,36 @@ public SendShuffleDataResult sendShuffleData(
}
}
/** Records the ShuffleServer that successfully or failed to send blocks */
Map<Long, List<ShuffleServerInfo>> blockIdSendSuccessTracker = JavaUtils.newConcurrentMap();
Map<Long, List<ShuffleServerInfo>> blockIdsSendFailTracker = JavaUtils.newConcurrentMap();
// we assume that most of the blocks can be sent successfully
// so initialize the map at first without concurrency insurance
// AtomicInteger is enough to reflect value changes in other threads
Map<Long, AtomicInteger> blockIdsSendSuccessTracker = Maps.newHashMap();
primaryServerToBlockIds
.values()
.forEach(
blockList ->
blockList.forEach(
block ->
blockIdsSendSuccessTracker.computeIfAbsent(
block, id -> new AtomicInteger(0))));
secondaryServerToBlockIds
.values()
.forEach(
blockList ->
blockList.forEach(
block ->
blockIdsSendSuccessTracker.computeIfAbsent(
block, id -> new AtomicInteger(0))));
Map<Long, BlockingQueue<ShuffleServerInfo>> blockIdsSendFailTracker =
JavaUtils.newConcurrentMap();

// sent the primary round of blocks.
boolean isAllSuccess =
sendShuffleDataAsync(
appId,
primaryServerToBlocks,
primaryServerToBlockIds,
blockIdSendSuccessTracker,
blockIdsSendSuccessTracker,
blockIdsSendFailTracker,
secondaryServerToBlocks.isEmpty(),
needCancelRequest);
@@ -380,20 +399,19 @@ public SendShuffleDataResult sendShuffleData(
appId,
secondaryServerToBlocks,
secondaryServerToBlockIds,
blockIdSendSuccessTracker,
blockIdsSendSuccessTracker,
blockIdsSendFailTracker,
true,
needCancelRequest);
}

blockIdSendSuccessTracker
Set<Long> blockIdsSendSuccessSet = Sets.newHashSet();
blockIdsSendSuccessTracker
.entrySet()
.forEach(
successBlockId -> {
if (successBlockId.getValue().size() < replicaWrite) {
// Removes blocks that do not reach replicaWrite from the success queue
blockIdSendSuccessTracker.remove(successBlockId.getKey());
} else {
if (successBlockId.getValue().get() >= replicaWrite) {
blockIdsSendSuccessSet.add(successBlockId.getKey());
// If the replicaWrite to be sent is reached,
// no matter whether the block fails to be sent or not,
// the block is considered to have been sent successfully and is removed from the
@@ -402,9 +420,7 @@ public SendShuffleDataResult sendShuffleData(
}
});
return new SendShuffleDataResult(
blockIdSendSuccessTracker.keySet(),
blockIdsSendFailTracker.keySet(),
blockIdsSendFailTracker);
blockIdsSendSuccessSet, blockIdsSendFailTracker.keySet(), blockIdsSendFailTracker);
}

/**

@@ -17,9 +17,9 @@

package org.apache.uniffle.client.response;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;

import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
@@ -28,7 +28,7 @@ public class SendShuffleDataResult {

private Set<Long> successBlockIds;
private Set<Long> failedBlockIds;
private Map<Long, List<ShuffleServerInfo>> sendFailedBlockIds;
private Map<Long, BlockingQueue<ShuffleServerInfo>> sendFailedBlockIds;

public SendShuffleDataResult(Set<Long> successBlockIds, Set<Long> failedBlockIds) {
this.successBlockIds = successBlockIds;
@@ -39,7 +39,7 @@ public SendShuffleDataResult(Set<Long> successBlockIds, Set<Long> failedBlockIds
public SendShuffleDataResult(
Set<Long> successBlockIds,
Set<Long> failedBlockIds,
Map<Long, List<ShuffleServerInfo>> sendFailedBlockIds) {
Map<Long, BlockingQueue<ShuffleServerInfo>> sendFailedBlockIds) {
this.successBlockIds = successBlockIds;
this.failedBlockIds = failedBlockIds;
this.sendFailedBlockIds = sendFailedBlockIds;
@@ -53,7 +53,7 @@ public Set<Long> getFailedBlockIds() {
return failedBlockIds;
}

public Map<Long, List<ShuffleServerInfo>> getSendFailedBlockIds() {
public Map<Long, BlockingQueue<ShuffleServerInfo>> getSendFailedBlockIds() {
return sendFailedBlockIds;
}
}
