Skip to content

Commit 4e43819

Browse files
yuhaiyang
authored and committed
[SPARK-34534] Fix blockIds order when use FetchShuffleBlocks to fetch blocks
### What changes were proposed in this pull request? Fix a problems which can lead to data correctness after part blocks retry in `OneForOneBlockFetcher` when use `FetchShuffleBlocks` . ### Why are the changes needed? This is a data correctness bug, It's is no problems when use old protocol to send `OpenBlocks` before fetch chunks in `OneForOneBlockFetcher`; In latest branch, `OpenBlocks` has been replaced to `FetchShuffleBlocks`. Howerver, `FetchShuffleBlocks` read shuffle blocks order is not the same as `blockIds` in `OneForOneBlockFetcher`; the `blockIds` is used to match blockId with shuffle data with index, now it is out of order; It will lead to read wrong block chunk when some blocks fetch failed in `OneForOneBlockFetcher`, it will retry the rest of the blocks in `blockIds` based on the `blockIds`'s order. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Closes apache#31643 from seayoun/yuhaiyang_fix_use_FetchShuffleBlocks_order. Lead-authored-by: yuhaiyang <yuhaiyang@yuhaiyangs-MacBook-Pro.local> Co-authored-by: yuhaiyang <yuhaiyang@172.19.25.126> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent ecf4811 commit 4e43819

File tree

2 files changed

+81
-14
lines changed

2 files changed

+81
-14
lines changed

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import java.nio.ByteBuffer;
2222
import java.util.ArrayList;
2323
import java.util.Arrays;
24-
import java.util.HashMap;
24+
import java.util.LinkedHashMap;
2525

2626
import com.google.common.primitives.Ints;
2727
import com.google.common.primitives.Longs;
@@ -81,7 +81,6 @@ public OneForOneBlockFetcher(
8181
TransportConf transportConf,
8282
DownloadFileManager downloadFileManager) {
8383
this.client = client;
84-
this.blockIds = blockIds;
8584
this.listener = listener;
8685
this.chunkCallback = new ChunkCallback();
8786
this.transportConf = transportConf;
@@ -90,8 +89,10 @@ public OneForOneBlockFetcher(
9089
throw new IllegalArgumentException("Zero-sized blockIds array");
9190
}
9291
if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
93-
this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
92+
this.blockIds = new String[blockIds.length];
93+
this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
9494
} else {
95+
this.blockIds = blockIds;
9596
this.message = new OpenBlocks(appId, execId, blockIds);
9697
}
9798
}
@@ -106,41 +107,53 @@ private boolean isShuffleBlocks(String[] blockIds) {
106107
}
107108

108109
/**
109-
* Analyze the pass in blockIds and create FetchShuffleBlocks message.
110-
* The blockIds has been sorted by mapId and reduceId. It's produced in
111-
* org.apache.spark.MapOutputTracker.convertMapStatuses.
110+
* Create FetchShuffleBlocks message and rebuild internal blockIds by
111+
* analyzing the pass in blockIds.
112112
*/
113-
private FetchShuffleBlocks createFetchShuffleBlocksMsg(
113+
private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
114114
String appId, String execId, String[] blockIds) {
115115
String[] firstBlock = splitBlockId(blockIds[0]);
116116
int shuffleId = Integer.parseInt(firstBlock[1]);
117117
boolean batchFetchEnabled = firstBlock.length == 5;
118118

119-
HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
119+
LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
120120
for (String blockId : blockIds) {
121121
String[] blockIdParts = splitBlockId(blockId);
122122
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
123123
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
124124
", got:" + blockId);
125125
}
126126
long mapId = Long.parseLong(blockIdParts[2]);
127-
if (!mapIdToReduceIds.containsKey(mapId)) {
128-
mapIdToReduceIds.put(mapId, new ArrayList<>());
127+
if (!mapIdToBlocksInfo.containsKey(mapId)) {
128+
mapIdToBlocksInfo.put(mapId, new BlocksInfo());
129129
}
130-
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
130+
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
131+
blocksInfoByMapId.blockIds.add(blockId);
132+
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
131133
if (batchFetchEnabled) {
132134
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
133135
// FetchShuffleBlocks to store the start and end reduce id for range
134136
// [startReduceId, endReduceId).
135137
assert(blockIdParts.length == 5);
136-
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
138+
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
137139
}
138140
}
139-
long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
141+
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
140142
int[][] reduceIdArr = new int[mapIds.length][];
143+
int blockIdIndex = 0;
141144
for (int i = 0; i < mapIds.length; i++) {
142-
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
145+
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
146+
reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
147+
148+
// The `blockIds`'s order must be the same as the read order specified in FetchShuffleBlocks
149+
// because the shuffle data's return order should match the `blockIds`'s order to ensure
150+
// blockId and data match.
151+
for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
152+
this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
153+
}
143154
}
155+
assert(blockIdIndex == this.blockIds.length);
156+
144157
return new FetchShuffleBlocks(
145158
appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
146159
}
@@ -157,6 +170,18 @@ private String[] splitBlockId(String blockId) {
157170
return blockIdParts;
158171
}
159172

173+
/** The reduceIds and blocks in a single mapId */
174+
private class BlocksInfo {
175+
176+
final ArrayList<Integer> reduceIds;
177+
final ArrayList<String> blockIds;
178+
179+
BlocksInfo() {
180+
this.reduceIds = new ArrayList<>();
181+
this.blockIds = new ArrayList<>();
182+
}
183+
}
184+
160185
/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
161186
private class ChunkCallback implements ChunkReceivedCallback {
162187
@Override

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,48 @@ public void testEmptyBlockFetch() {
201201
}
202202
}
203203

204+
@Test
205+
public void testFetchShuffleBlocksOrder() {
206+
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
207+
blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
208+
blocks.put("shuffle_0_2_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[2])));
209+
blocks.put("shuffle_0_10_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3])));
210+
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
211+
212+
BlockFetchingListener listener = fetchBlocks(
213+
blocks,
214+
blockIds,
215+
new FetchShuffleBlocks("app-id", "exec-id", 0,
216+
new long[]{0, 2, 10}, new int[][]{{0}, {1}, {2}}, false),
217+
conf);
218+
219+
for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) {
220+
String blockId = blockIds[chunkIndex];
221+
verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId));
222+
}
223+
}
224+
225+
@Test
226+
public void testBatchFetchShuffleBlocksOrder() {
227+
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
228+
blocks.put("shuffle_0_0_1_2", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
229+
blocks.put("shuffle_0_2_2_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[2])));
230+
blocks.put("shuffle_0_10_3_4", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3])));
231+
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
232+
233+
BlockFetchingListener listener = fetchBlocks(
234+
blocks,
235+
blockIds,
236+
new FetchShuffleBlocks("app-id", "exec-id", 0,
237+
new long[]{0, 2, 10}, new int[][]{{1, 2}, {2, 3}, {3, 4}}, true),
238+
conf);
239+
240+
for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) {
241+
String blockId = blockIds[chunkIndex];
242+
verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId));
243+
}
244+
}
245+
204246
/**
205247
* Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
206248
* simply returns the given (BlockId, Block) pairs.

0 commit comments

Comments
 (0)