[apache#731] feat(spark): Make blockid layout configurable for Spark clients (apache#1528)
### What changes were proposed in this pull request?
Make the bit-lengths in the block id (the block id layout) configurable, either through the dynamic client config served by the coordinator or through client-side config. The block id layout can be created from an RssConf; where an RssConf is not available, the layout is passed around explicitly.
 
- Adds new options (defaults are equivalent to the current values in `Constants`):
  - `rss.client.blockId.sequenceNoBits`
  - `rss.client.blockId.partitionIdBits`
  - `rss.client.blockId.taskAttemptIdBits`
- Adds the block id layout to two requests (the default is a layout with the current values in `Constants`).

Default values have moved from `Constants` into `BlockIdLayout`. The following replacements exist:
- `PARTITION_ID_MAX_LENGTH`: `BlockIdLayout.DEFAULT.partitionIdBits`
- `TASK_ATTEMPT_ID_MAX_LENGTH`: `BlockIdLayout.DEFAULT.taskAttemptIdBits`
- `ATOMIC_INT_MAX_LENGTH`: `BlockIdLayout.DEFAULT.sequenceNoBits`
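
For intuition, here is a minimal packing sketch. It assumes the layout orders the fields as sequence number, partition id, task attempt id from high bits to low (the shifts in the diffs below imply this) and uses the historical 18/24/21 split as an example; it is illustrative only, not Uniffle's `BlockIdLayout` implementation:

```java
// Illustrative sketch only. Assumed field order, high to low:
// [sequenceNo | partitionId | taskAttemptId], 63 usable bits of a positive long.
public class BlockIdPackingSketch {
  public static void main(String[] args) {
    int sequenceNoBits = 18, partitionIdBits = 24, taskAttemptIdBits = 21; // sums to 63

    int seqNo = 7;
    int partitionId = 3;
    long taskAttemptId = 42;

    long blockId =
        ((long) seqNo << (partitionIdBits + taskAttemptIdBits))
            | ((long) partitionId << taskAttemptIdBits)
            | taskAttemptId;

    // Unpacking mirrors the shifts and masks used by the tests in this commit.
    long partitionMask = (1L << partitionIdBits) - 1;
    int unpackedPartition = (int) ((blockId >> taskAttemptIdBits) & partitionMask);
    System.out.println(unpackedPartition == partitionId); // true
  }
}
```

With the three widths configurable, a job with many partitions can trade sequence-number bits for partition-id bits without recompiling.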

### Why are the changes needed?
The bit-lengths of the sequence number, partition id, and task attempt id in the block id are defined in `common/src/main/java/org/apache/uniffle/common/util/Constants.java`. Changing them requires recompiling and redeploying the project. Making them configurable in `coordinator.conf`, `server.conf`, or on the client side would be very useful.

Also see apache#1512, apache#749.

Fixes apache#731.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests, including new parameterized tests that exercise non-default block id layouts (see the diffs below).
EnricoMi authored Mar 7, 2024
1 parent ec4251d commit dc890ff
Showing 54 changed files with 1,459 additions and 462 deletions.
**RssMRUtils.java**

```diff
@@ -37,23 +37,24 @@
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
-import org.apache.uniffle.common.util.BlockId;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.Constants;
 
 public class RssMRUtils {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
+  private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
   private static final int MAX_ATTEMPT_LENGTH = 6;
   private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
   private static final int MAX_SEQUENCE_NO =
-      (1 << (Constants.ATOMIC_INT_MAX_LENGTH - MAX_ATTEMPT_LENGTH)) - 1;
+      (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;
 
   // Class TaskAttemptId have two field id and mapId, rss taskAttemptID have 21 bits,
   // mapId is 19 bits, id is 2 bits. MR have a trick logic, taskAttemptId will increase
   // 1000 * (appAttemptId - 1), so we will decrease it.
   public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int appAttemptId) {
     int lowBytes = taskAttemptID.getTaskID().getId();
-    if (lowBytes > Constants.MAX_TASK_ATTEMPT_ID) {
+    if (lowBytes > LAYOUT.maxTaskAttemptId) {
       throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed");
     }
     if (appAttemptId < 1) {
@@ -64,16 +65,16 @@ public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int a
       throw new RssException(
           "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed");
     }
-    return BlockId.getBlockId(highBytes, 0, lowBytes);
+    return LAYOUT.getBlockId(highBytes, 0, lowBytes);
   }
 
   public static TaskAttemptID createMRTaskAttemptId(
       JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId) {
     if (appAttemptId < 1) {
       throw new RssException("appAttemptId " + appAttemptId + " is wrong");
     }
-    TaskID taskID = new TaskID(jobID, taskType, BlockId.getTaskAttemptId(rssTaskAttemptId));
-    int id = BlockId.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 1);
+    TaskID taskID = new TaskID(jobID, taskType, LAYOUT.getTaskAttemptId(rssTaskAttemptId));
+    int id = LAYOUT.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 1);
     return new TaskAttemptID(taskID, id);
   }
 
@@ -227,8 +228,7 @@ public static String getString(Configuration rssJobConf, String key, String defa
   }
 
   public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) {
-    long attemptId =
-        taskAttemptId >> (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
+    long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits);
     if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
       throw new RssException(
           "Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID);
@@ -240,17 +240,15 @@ public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo
 
     int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
     long taskId =
-        taskAttemptId
-            - (attemptId
-                << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH));
+        taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits));
 
-    return BlockId.getBlockId(atomicInt, partitionId, taskId);
+    return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
   }
 
   public static long getTaskAttemptId(long blockId) {
-    int mapId = BlockId.getTaskAttemptId(blockId);
-    int attemptId = BlockId.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
-    return BlockId.getBlockId(attemptId, 0, mapId);
+    int mapId = LAYOUT.getTaskAttemptId(blockId);
+    int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
+    return LAYOUT.getBlockId(attemptId, 0, mapId);
   }
 
   public static int estimateTaskConcurrency(JobConf jobConf) {
```
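Two of the hunks above cooperate: `getBlockId` folds the app-attempt bits into the low end of the sequence-number field, and `getTaskAttemptId` masks them back out with `MAX_ATTEMPT_ID`. A hedged round-trip sketch, assuming `RssMRUtils` lives in `org.apache.hadoop.mapreduce` as Uniffle's MR client conventionally does:

```java
import org.apache.hadoop.mapreduce.RssMRUtils;
import org.apache.hadoop.mapreduce.TaskAttemptID;

// Hypothetical helper, not part of the PR: checks that a rss task attempt id
// survives a round trip through a block id.
public class RoundTripSketch {
  static void check(TaskAttemptID taskAttemptID, int appAttemptId, int partitionId, int seqNo) {
    long rssTaskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, appAttemptId);
    long blockId = RssMRUtils.getBlockId(partitionId, rssTaskAttemptId, seqNo);
    // Holds while seqNo <= MAX_SEQUENCE_NO and both ids fit their bit fields.
    assert RssMRUtils.getTaskAttemptId(blockId) == rssTaskAttemptId;
  }
}
```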
**RssMRUtilsTest.java**

```diff
@@ -25,6 +25,7 @@
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.storage.util.StorageType;
@@ -80,16 +81,16 @@ public void blockConvertTest() {
 
   @Test
   public void partitionIdConvertBlockTest() {
+    BlockIdLayout layout = BlockIdLayout.DEFAULT;
     JobID jobID = new JobID();
     TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
     TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
     long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1);
-    long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1;
+    long mask = (1L << layout.partitionIdBits) - 1;
     for (int partitionId = 0; partitionId <= 3000; partitionId++) {
       for (int seqNo = 0; seqNo <= 10; seqNo++) {
         long blockId = RssMRUtils.getBlockId(partitionId, taskAttemptId, seqNo);
-        int newPartitionId =
-            Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask);
+        int newPartitionId = Math.toIntExact((blockId >> layout.taskAttemptIdBits) & mask);
         assertEquals(partitionId, newPartitionId);
       }
     }
```
**WriteBufferManager.java**

```diff
@@ -51,7 +51,7 @@
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
-import org.apache.uniffle.common.util.BlockId;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 
 public class WriteBufferManager extends MemoryConsumer {
@@ -97,6 +97,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private boolean memorySpillEnabled;
   private int memorySpillTimeoutSec;
   private boolean isRowBased;
+  private BlockIdLayout blockIdLayout;
 
   public WriteBufferManager(
       int shuffleId,
@@ -162,6 +163,7 @@ public WriteBufferManager(
     this.sendSizeLimit = rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
     this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
     this.memorySpillEnabled = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
+    this.blockIdLayout = BlockIdLayout.from(rssConf);
   }
 
   /** add serialized columnar data directly when integrate with gluten */
@@ -329,7 +331,8 @@ protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb)
       compressTime += System.currentTimeMillis() - start;
     }
     final long crc32 = ChecksumUtils.getCrc32(compressed);
-    final long blockId = BlockId.getBlockId(getNextSeqNo(partitionId), partitionId, taskAttemptId);
+    final long blockId =
+        blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, taskAttemptId);
     uncompressedDataLen += data.length;
     shuffleWriteMetrics.incBytesWritten(compressed.length);
     // add memory to indicate bytes which will be sent to shuffle server
```
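The constructor hunk reads the layout once from the job's `RssConf`, so the hot write path does no per-block config lookups. A sketch of how the new keys from the description would feed it; the string-setter API on `RssConf` is an assumption here, only `BlockIdLayout.from(rssConf)` and the key names come from this PR:

```java
// Assumed wiring, for illustration only.
RssConf rssConf = new RssConf();
rssConf.setString("rss.client.blockId.sequenceNoBits", "20");
rssConf.setString("rss.client.blockId.partitionIdBits", "21");
rssConf.setString("rss.client.blockId.taskAttemptIdBits", "22");

BlockIdLayout layout = BlockIdLayout.from(rssConf); // same call as the constructor above
long blockId = layout.getBlockId(/* seqNo */ 0, /* partitionId */ 5, /* taskAttemptId */ 1L);
```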
**RssShuffleManagerBase.java**

```diff
@@ -23,7 +23,6 @@
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.MapOutputTracker;
@@ -51,8 +50,11 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
   private Method unregisterAllMapOutputMethod;
   private Method registerShuffleMethod;
 
+  /** See static overload of this method. */
+  public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);
+
   /**
-   * Provides a task attempt id that is unique for a shuffle stage.
+   * Provides a task attempt id to be used in the block id, that is unique for a shuffle stage.
    *
    * <p>We are not using context.taskAttemptId() here as this is a monotonically increasing number
    * that is unique across the entire Spark app which can reach very large numbers, which can
@@ -64,8 +66,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
    *
    * @return a task attempt id unique for a shuffle stage
    */
-  @VisibleForTesting
-  protected static long getTaskAttemptId(
+  protected static long getTaskAttemptIdForBlockId(
       int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
     // attempt number is zero based: 0, 1, …, maxFailures-1
     // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1
```
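The javadoc above states the goal; one packing that satisfies it is sketched below. This illustrates the idea and is not claimed to be the exact Uniffle implementation:

```java
// Sketch: build a per-stage-unique id from (mapIndex, attemptNo). The attempt
// number sits in the low bits, the map index above it, and the result must
// fit into the layout's task-attempt-id field.
static long taskAttemptIdForBlockIdSketch(
    int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
  // attempt numbers are zero-based: 0 .. maxFailures-1, plus one more with speculation
  int maxAttemptNo = Math.max(maxFailures, 1) - 1 + (speculation ? 1 : 0);
  int attemptBits = 32 - Integer.numberOfLeadingZeros(maxAttemptNo); // 0 when maxAttemptNo == 0
  long id = ((long) mapIndex << attemptBits) | attemptNo;
  if (64 - Long.numberOfLeadingZeros(id) > maxTaskAttemptIdBits) {
    throw new IllegalArgumentException("id does not fit into " + maxTaskAttemptIdBits + " bits");
  }
  return id;
}
```

Keeping ids compact this way is exactly why the javadoc rejects `context.taskAttemptId()`: an app-wide monotonic counter would quickly overflow a small task-attempt-id field.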
**AbstractRssReaderTest.java** (shared test helper; the class name is inferred from `RssShuffleDataIteratorTest` below, which extends it)

```diff
@@ -37,7 +37,7 @@
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssConf;
-import org.apache.uniffle.common.util.BlockId;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.storage.HadoopTestBase;
 import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
@@ -76,6 +76,7 @@ protected void writeTestData(
         handler,
         blockNum,
         recordNum,
+        BlockIdLayout.DEFAULT,
         expectedData,
         blockIdBitmap,
         keyPrefix,
@@ -88,6 +89,55 @@
       ShuffleWriteHandler handler,
       int blockNum,
       int recordNum,
+      BlockIdLayout layout,
       Map<String, String> expectedData,
       Roaring64NavigableMap blockIdBitmap,
       String keyPrefix,
+      Serializer serializer,
+      int partitionID)
+      throws Exception {
+    writeTestData(
+        handler,
+        blockNum,
+        recordNum,
+        layout,
+        expectedData,
+        blockIdBitmap,
+        keyPrefix,
+        serializer,
+        partitionID,
+        true);
+  }
+
+  protected void writeTestData(
+      ShuffleWriteHandler handler,
+      int blockNum,
+      int recordNum,
+      Map<String, String> expectedData,
+      Roaring64NavigableMap blockIdBitmap,
+      String keyPrefix,
+      Serializer serializer,
+      int partitionID,
+      boolean compress)
+      throws Exception {
+    writeTestData(
+        handler,
+        blockNum,
+        recordNum,
+        BlockIdLayout.DEFAULT,
+        expectedData,
+        blockIdBitmap,
+        keyPrefix,
+        serializer,
+        partitionID,
+        compress);
+  }
+
+  protected void writeTestData(
+      ShuffleWriteHandler handler,
+      int blockNum,
+      int recordNum,
+      BlockIdLayout layout,
+      Map<String, String> expectedData,
+      Roaring64NavigableMap blockIdBitmap,
+      String keyPrefix,
@@ -106,7 +156,7 @@ protected void writeTestData(
       expectedData.put(key, value);
       writeData(serializeStream, key, value);
     }
-    long blockId = BlockId.getBlockId(atomicInteger.getAndIncrement(), partitionID, 0);
+    long blockId = layout.getBlockId(atomicInteger.getAndIncrement(), partitionID, 0);
     blockIdBitmap.add(blockId);
     blocks.add(createShuffleBlock(output.toBytes(), blockId, compress));
     serializeStream.close();
```
**RssShuffleDataIteratorTest.java**

```diff
@@ -20,6 +20,7 @@
 import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Stream;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -33,6 +34,9 @@
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.mockito.MockedStatic;
 import org.mockito.Mockito;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -42,9 +46,8 @@
 import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssConf;
-import org.apache.uniffle.common.util.BlockId;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
-import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -65,23 +68,30 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
   private ShuffleServerInfo ssi1 = new ShuffleServerInfo("host1-0", "host1", 0);
   private ShuffleServerInfo ssi2 = new ShuffleServerInfo("host2-0", "host2", 0);
 
-  @Test
-  public void readTest1() throws Exception {
+  public static Stream<Arguments> testBlockIdLayouts() {
+    return Stream.of(
+        Arguments.of(BlockIdLayout.DEFAULT), Arguments.of(BlockIdLayout.from(20, 21, 22)));
+  }
+
+  @ParameterizedTest
+  @MethodSource("testBlockIdLayouts")
+  public void readTest1(BlockIdLayout layout) throws Exception {
     String basePath = HDFS_URI + "readTest1";
     HadoopShuffleWriteHandler writeHandler =
         new HadoopShuffleWriteHandler("appId", 0, 0, 1, basePath, ssi1.getId(), conf);
 
     Map<String, String> expectedData = Maps.newHashMap();
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
     Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
-    writeTestData(writeHandler, 2, 5, expectedData, blockIdBitmap, "key", KRYO_SERIALIZER, 0);
+    writeTestData(
+        writeHandler, 2, 5, layout, expectedData, blockIdBitmap, "key", KRYO_SERIALIZER, 0);
 
     RssShuffleDataIterator rssShuffleDataIterator =
         getDataIterator(basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(ssi1));
 
     validateResult(rssShuffleDataIterator, expectedData, 10);
 
-    blockIdBitmap.add(BlockId.getBlockId(Constants.MAX_SEQUENCE_NO, 0, 0));
+    blockIdBitmap.add(layout.getBlockId(layout.maxSequenceNo, 0, 0));
     rssShuffleDataIterator =
         getDataIterator(basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(ssi1));
     int recNum = 0;
@@ -270,7 +280,9 @@ public void readTest7() throws Exception {
       }
       fail(EXPECTED_EXCEPTION_MESSAGE);
     } catch (Exception e) {
-      assertTrue(e.getMessage().startsWith("Unexpected crc value"));
+      assertTrue(
+          e.getMessage()
+              .startsWith("Unexpected crc value for blockId[0 (seq: 0, part: 0, task: 0)]"));
     }
 
     try {
```
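A usage note on the parameterized test above: `BlockIdLayout.from(20, 21, 22)` picks widths that still sum to 63, the number of usable bits in a positive long, and presumably `from` validates that sum. The field names in this check sketch come from the diffs; the validation behavior is an assumption:

```java
BlockIdLayout layout = BlockIdLayout.from(20, 21, 22); // 20 + 21 + 22 == 63
assert layout.sequenceNoBits + layout.partitionIdBits + layout.taskAttemptIdBits == 63;
assert layout.maxSequenceNo == (1 << 20) - 1; // largest sequence number this layout can express
```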