From ec7f85c3d0d56286b46bcfcb50f1de75ba1d736d Mon Sep 17 00:00:00 2001 From: summaryzb Date: Wed, 25 Oct 2023 11:03:26 +0800 Subject: [PATCH] [#825][FOLLOWUP] fix(spark): Apply a thread safety way to track the blocks sending result (#1260) ### What changes were proposed in this pull request? As the title says: track the results of sending blocks in a thread-safe way. ### Why are the changes needed? ``` [INFO] Running org.apache.uniffle.test.ContinuousSelectPartitionStrategyTest Error: Tests run: 1, Failures: 0, Errors: 1, Skipped: 0, Time elapsed: 59.195 s <<< FAILURE! - in org.apache.uniffle.test.ContinuousSelectPartitionStrategyTest Error: resultCompareTest Time elapsed: 55.751 s <<< ERROR! org.apache.spark.SparkException: Job aborted due to stage failure: Task 6 in stage 1.0 failed 1 times, most recent failure: Lost task 6.0 in stage 1.0 (TID 16) (fv-az391-410.nf14wd45lyte3l5gjbhk121dmd.jx.internal.cloudapp.net executor driver): org.apache.uniffle.common.exception.RssException: Timeout: Task[16_0] failed because 9 blocks can't be sent to shuffle server in 30000 ms. at org.apache.spark.shuffle.writer.RssShuffleWriter.checkBlockSendResult(RssShuffleWriter.java:350) at org.apache.spark.shuffle.writer.RssShuffleWriter.writeImpl(RssShuffleWriter.java:246) at org.apache.spark.shuffle.writer.RssShuffleWriter.write(RssShuffleWriter.java:209) at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52) ``` [ActionLink](https://github.com/apache/incubator-uniffle/actions/runs/6611324517/job/17954967498?pr=1257) I debugged the test locally and found that all blocks had been sent successfully, but some of them were missing from the block tracker. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Integration tests. In particular, the test below was run in a loop many times without a single failure: ``` mvn -B -fae test -Dtest=org.apache.uniffle.test.ContinuousSelectPartitionStrategyTest -Pspark3 ``` --- .../spark/shuffle/writer/DataPusher.java | 10 ++-- .../spark/shuffle/writer/DataPusherTest.java | 3 +- .../spark/shuffle/RssShuffleManager.java | 12 ++-- .../shuffle/writer/RssShuffleWriterTest.java | 3 +- .../spark/shuffle/RssShuffleManager.java | 15 +++-- .../org/apache/spark/shuffle/TestUtils.java | 4 +- .../shuffle/writer/RssShuffleWriterTest.java | 7 +-- .../client/impl/ShuffleWriteClientImpl.java | 56 ++++++++++++------- .../response/SendShuffleDataResult.java | 8 +-- 9 files changed, 71 insertions(+), 47 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java index b578ac0174..68ec8fb3b8 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; @@ -56,7 +57,8 @@ public class DataPusher implements Closeable { private final Map> taskToSuccessBlockIds; // Must be thread safe private final Map> taskToFailedBlockIds; - private final Map>> taskToFailedBlockIdsAndServer; + private final Map>> + taskToFailedBlockIdsAndServer; private String rssAppId; // Must be thread safe private final Set failedTaskIds; @@ -65,7 +67,7 @@ public DataPusher( ShuffleWriteClient shuffleWriteClient, Map> taskToSuccessBlockIds, Map> taskToFailedBlockIds, - Map>> taskToFailedBlockIdsAndServer, + Map>> taskToFailedBlockIdsAndServer, Set failedTaskIds, int threadPoolSize, int 
threadKeepAliveTime) { @@ -126,9 +128,9 @@ private synchronized void putBlockId( } private synchronized void putSendFailedBlockIdAndShuffleServer( - Map>> taskToFailedBlockIdsAndServer, + Map>> taskToFailedBlockIdsAndServer, String taskAttemptId, - Map> blockIdsAndServer) { + Map> blockIdsAndServer) { if (blockIdsAndServer == null || blockIdsAndServer.isEmpty()) { return; } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java index 21a8499dba..a3cdbb6e59 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.function.Supplier; @@ -81,7 +82,7 @@ public void testSendData() throws ExecutionException, InterruptedException { Map> taskToSuccessBlockIds = Maps.newConcurrentMap(); Map> taskToFailedBlockIds = Maps.newConcurrentMap(); - Map>> taskToFailedBlockIdsAndServer = + Map>> taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap(); Set failedTaskIds = new HashSet<>(); diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index d1f1d9a6c6..4efde2ec50 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import 
java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -87,8 +88,8 @@ public class RssShuffleManager extends RssShuffleManagerBase { private Map> taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); private Map> taskToFailedBlockIds = JavaUtils.newConcurrentMap(); // Record both the block that failed to be sent and the ShuffleServer - private final Map>> taskToFailedBlockIdsAndServer = - JavaUtils.newConcurrentMap(); + private final Map>> + taskToFailedBlockIdsAndServer = JavaUtils.newConcurrentMap(); private final int dataReplica; private final int dataReplicaWrite; private final int dataReplicaRead; @@ -703,10 +704,11 @@ private Roaring64NavigableMap getShuffleResult( * @param taskId Shuffle taskId * @return List of failed ShuffleServer blocks */ - public Map> getFailedBlockIdsWithShuffleServer(String taskId) { - Map> result = taskToFailedBlockIdsAndServer.get(taskId); + public Map> getFailedBlockIdsWithShuffleServer( + String taskId) { + Map> result = taskToFailedBlockIdsAndServer.get(taskId); if (result == null) { - result = JavaUtils.newConcurrentMap(); + result = Collections.emptyMap(); } return result; } diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 86115c699f..e60930d8cf 100644 --- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.function.Function; import java.util.stream.Collectors; @@ -162,7 +163,7 @@ private FakedDataPusher( ShuffleWriteClient shuffleWriteClient, Map> taskToSuccessBlockIds, Map> taskToFailedBlockIds, - Map>> 
taskToFailedBlockIdsAndServer, + Map>> taskToFailedBlockIdsAndServer, Set failedTaskIds, int threadPoolSize, int threadKeepAliveTime, diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index a917b456b8..9ec5e90ae3 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -98,7 +99,8 @@ public class RssShuffleManager extends RssShuffleManagerBase { private final Map> taskToSuccessBlockIds; private final Map> taskToFailedBlockIds; // Record both the block that failed to be sent and the ShuffleServer - private final Map>> taskToFailedBlockIdsAndServer; + private final Map>> + taskToFailedBlockIdsAndServer; private ScheduledExecutorService heartBeatScheduledExecutorService; private boolean heartbeatStarted = false; private boolean dynamicConfEnabled = false; @@ -270,7 +272,7 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s DataPusher dataPusher, Map> taskToSuccessBlockIds, Map> taskToFailedBlockIds, - Map>> taskToFailedBlockIdsAndServer) { + Map>> taskToFailedBlockIdsAndServer) { this.sparkConf = conf; this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); this.dataDistributionType = @@ -999,12 +1001,13 @@ private Roaring64NavigableMap getShuffleResultForMultiPart( * The ShuffleServer list of block sending failures is returned using the shuffle task ID * * @param taskId Shuffle taskId - * @return List of failed ShuffleServer blocks + * @return failed ShuffleServer blocks */ - public Map> 
getFailedBlockIdsWithShuffleServer(String taskId) { - Map> result = taskToFailedBlockIdsAndServer.get(taskId); + public Map> getFailedBlockIdsWithShuffleServer( + String taskId) { + Map> result = taskToFailedBlockIdsAndServer.get(taskId); if (result == null) { - result = JavaUtils.newConcurrentMap(); + result = Collections.emptyMap(); } return result; } diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java index d0424f4f2f..2312424751 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java @@ -17,9 +17,9 @@ package org.apache.spark.shuffle; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import org.apache.commons.lang3.SystemUtils; import org.apache.spark.SparkConf; @@ -37,7 +37,7 @@ public static RssShuffleManager createShuffleManager( DataPusher dataPusher, Map> successBlockIds, Map> failBlockIds, - Map>> taskToFailedBlockIdsAndServer) { + Map>> taskToFailedBlockIdsAndServer) { return new RssShuffleManager( conf, isDriver, dataPusher, successBlockIds, failBlockIds, taskToFailedBlockIdsAndServer); } diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 09eeff7c42..ca99362888 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.function.Function; import java.util.stream.Collectors; @@ -82,12 +83,10 @@ 
public void checkBlockSendResultTest() { .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346"); Map> failBlocks = JavaUtils.newConcurrentMap(); Map> successBlocks = JavaUtils.newConcurrentMap(); - Map>> taskToFailedBlockIdsAndServer = - JavaUtils.newConcurrentMap(); Serializer kryoSerializer = new KryoSerializer(conf); RssShuffleManager manager = TestUtils.createShuffleManager( - conf, false, null, successBlocks, failBlocks, taskToFailedBlockIdsAndServer); + conf, false, null, successBlocks, failBlocks, JavaUtils.newConcurrentMap()); ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); Partitioner mockPartitioner = mock(Partitioner.class); @@ -164,7 +163,7 @@ private FakedDataPusher( ShuffleWriteClient shuffleWriteClient, Map> taskToSuccessBlockIds, Map> taskToFailedBlockIds, - Map>> taskToFailedBlockIdsAndServer, + Map>> taskToFailedBlockIdsAndServer, Set failedTaskIds, int threadPoolSize, int threadKeepAliveTime, diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index f777f6ff85..ff7aad7194 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -23,12 +23,14 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; @@ -151,8 +153,8 @@ private boolean sendShuffleDataAsync( String appId, Map>>> 
serverToBlocks, Map> serverToBlockIds, - Map> blockIdsSendSuccessTracker, - Map> blockIdsSendFailTracker, + Map blockIdsSendSuccessTracker, + Map> blockIdsSendFailTracker, boolean allowFastFail, Supplier needCancelRequest) { @@ -195,10 +197,7 @@ private boolean sendShuffleDataAsync( serverToBlockIds .get(ssi) .forEach( - blockId -> - blockIdsSendSuccessTracker - .computeIfAbsent(blockId, id -> Lists.newArrayList()) - .add(ssi)); + blockId -> blockIdsSendSuccessTracker.get(blockId).incrementAndGet()); if (defectiveServers != null) { defectiveServers.remove(ssi); } @@ -211,7 +210,7 @@ private boolean sendShuffleDataAsync( .forEach( blockId -> blockIdsSendFailTracker - .computeIfAbsent(blockId, id -> Lists.newArrayList()) + .computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>()) .add(ssi)); if (defectiveServers != null) { defectiveServers.add(ssi); @@ -225,7 +224,7 @@ private boolean sendShuffleDataAsync( .forEach( blockId -> blockIdsSendFailTracker - .computeIfAbsent(blockId, id -> Lists.newArrayList()) + .computeIfAbsent(blockId, id -> new LinkedBlockingQueue<>()) .add(ssi)); if (defectiveServers != null) { defectiveServers.add(ssi); @@ -355,8 +354,28 @@ public SendShuffleDataResult sendShuffleData( } } /** Records the ShuffleServer that successfully or failed to send blocks */ - Map> blockIdSendSuccessTracker = JavaUtils.newConcurrentMap(); - Map> blockIdsSendFailTracker = JavaUtils.newConcurrentMap(); + // we assume that most of the blocks can be sent successfully + // so initialize the map at first without concurrency insurance + // AtomicInteger is enough to reflect value changes in other threads + Map blockIdsSendSuccessTracker = Maps.newHashMap(); + primaryServerToBlockIds + .values() + .forEach( + blockList -> + blockList.forEach( + block -> + blockIdsSendSuccessTracker.computeIfAbsent( + block, id -> new AtomicInteger(0)))); + secondaryServerToBlockIds + .values() + .forEach( + blockList -> + blockList.forEach( + block -> + 
blockIdsSendSuccessTracker.computeIfAbsent( + block, id -> new AtomicInteger(0)))); + Map> blockIdsSendFailTracker = + JavaUtils.newConcurrentMap(); // sent the primary round of blocks. boolean isAllSuccess = @@ -364,7 +383,7 @@ public SendShuffleDataResult sendShuffleData( appId, primaryServerToBlocks, primaryServerToBlockIds, - blockIdSendSuccessTracker, + blockIdsSendSuccessTracker, blockIdsSendFailTracker, secondaryServerToBlocks.isEmpty(), needCancelRequest); @@ -380,20 +399,19 @@ public SendShuffleDataResult sendShuffleData( appId, secondaryServerToBlocks, secondaryServerToBlockIds, - blockIdSendSuccessTracker, + blockIdsSendSuccessTracker, blockIdsSendFailTracker, true, needCancelRequest); } - blockIdSendSuccessTracker + Set blockIdsSendSuccessSet = Sets.newHashSet(); + blockIdsSendSuccessTracker .entrySet() .forEach( successBlockId -> { - if (successBlockId.getValue().size() < replicaWrite) { - // Removes blocks that do not reach replicaWrite from the success queue - blockIdSendSuccessTracker.remove(successBlockId.getKey()); - } else { + if (successBlockId.getValue().get() >= replicaWrite) { + blockIdsSendSuccessSet.add(successBlockId.getKey()); // If the replicaWrite to be sent is reached, // no matter whether the block fails to be sent or not, // the block is considered to have been sent successfully and is removed from the @@ -402,9 +420,7 @@ public SendShuffleDataResult sendShuffleData( } }); return new SendShuffleDataResult( - blockIdSendSuccessTracker.keySet(), - blockIdsSendFailTracker.keySet(), - blockIdsSendFailTracker); + blockIdsSendSuccessSet, blockIdsSendFailTracker.keySet(), blockIdsSendFailTracker); } /** diff --git a/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java b/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java index d33a298621..f2d820e60f 100644 --- a/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java +++ 
b/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java @@ -17,9 +17,9 @@ package org.apache.uniffle.client.response; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.BlockingQueue; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.util.JavaUtils; @@ -28,7 +28,7 @@ public class SendShuffleDataResult { private Set successBlockIds; private Set failedBlockIds; - private Map> sendFailedBlockIds; + private Map> sendFailedBlockIds; public SendShuffleDataResult(Set successBlockIds, Set failedBlockIds) { this.successBlockIds = successBlockIds; @@ -39,7 +39,7 @@ public SendShuffleDataResult(Set successBlockIds, Set failedBlockIds public SendShuffleDataResult( Set successBlockIds, Set failedBlockIds, - Map> sendFailedBlockIds) { + Map> sendFailedBlockIds) { this.successBlockIds = successBlockIds; this.failedBlockIds = failedBlockIds; this.sendFailedBlockIds = sendFailedBlockIds; @@ -53,7 +53,7 @@ public Set getFailedBlockIds() { return failedBlockIds; } - public Map> getSendFailedBlockIds() { + public Map> getSendFailedBlockIds() { return sendFailedBlockIds; } }