Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1496] Differentiate map results with only different stageAttemptId #2609

Closed
wants to merge 13 commits into from
Closed
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.DAGScheduler;

public class SparkCommonUtils {
/**
 * Verifies that the configured Spark stage/task attempt limits fit into the encoded map
 * attemptId.
 *
 * <p>The map attemptId is a non-negative int built from both the stage attempt number and the
 * task attempt number: the high 16 bits carry the stage attempt number and the low 16 bits
 * carry the task attempt number. To keep the result non-negative and collision free,
 * {@code spark.stage.maxConsecutiveAttempts} must be less than 32768 (1 << 15) and
 * {@code spark.task.maxFailures} must be less than 65536 (1 << 16).
 *
 * @param conf the Spark configuration to read the two limits from
 * @throws IllegalArgumentException if either limit is too large to be encoded
 */
public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException {
  int maxStageAttempts =
      conf.getInt(
          "spark.stage.maxConsecutiveAttempts",
          DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
  // In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in Spark 3, it has been
  // changed to TASK_MAX_FAILURES. The default value for both is consistently set to 4.
  int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
  if (maxStageAttempts >= (1 << 15) || maxTaskAttempts >= (1 << 16)) {
    // Keep the leading part of this message stable: callers and tests match on the
    // "The spark.stage.maxConsecutiveAttempts should be less than 32768" prefix.
    throw new IllegalArgumentException(
        "The spark.stage.maxConsecutiveAttempts should be less than 32768 (currently "
            + maxStageAttempts
            + ") and spark.task.maxFailures should be less than 65536 (currently "
            + maxTaskAttempts
            + ").");
  }
}

/**
 * Combines the stage attempt number and the task attempt number of a task into a single int.
 *
 * <p>The stage attempt number is shifted left by 16 bits and OR-ed with the task attempt
 * number, so the stage attempt number occupies the high 16 bits and the task attempt number
 * the low 16 bits of the result.
 *
 * <p>NOTE(review): the encoding is only unambiguous and non-negative when the stage attempt
 * number is below 32768 and the task attempt number is below 65536; presumably
 * {@code validateAttemptConfig} is what guarantees this - confirm it runs before any task
 * reaches this method.
 *
 * @param context the task context supplying both attempt numbers
 * @return the encoded attempt number
 */
public static int getEncodedAttemptNumber(TaskContext context) {
return (context.stageAttemptNumber() << 16) | context.attemptNumber();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed earlier (I cant seem to find the ref :-) ) - please do submit a PR to Apache Spark as well for this - and ensure the communities can align on this assumption.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
Expand Down Expand Up @@ -112,6 +113,7 @@ public HashBasedShuffleWriter(
this.mapId = mapId;
this.dep = handle.dependency();
this.shuffleId = shuffleId;
this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
Expand Down Expand Up @@ -146,7 +148,7 @@ public HashBasedShuffleWriter(
new DataPusher(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
Expand Down Expand Up @@ -278,7 +280,7 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
shuffleClient.pushData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
0,
Expand Down Expand Up @@ -323,7 +325,7 @@ private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by dataPusher thread
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
shuffleClient.prepareForMergeData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);

// merge and push residual data to reduce network traffic
// NB: since dataPusher thread have no in-flight data at this point,
Expand All @@ -335,7 +337,7 @@ private void close() throws IOException, InterruptedException {
shuffleClient.mergeData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
i,
sendBuffers[i],
0,
Expand All @@ -348,7 +350,7 @@ private void close() throws IOException, InterruptedException {
writeMetrics.incBytesWritten(bytesWritten);
}
}
shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);

updateMapStatus();

Expand All @@ -357,7 +359,7 @@ private void close() throws IOException, InterruptedException {
sendOffsets = null;

long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
Expand Down Expand Up @@ -394,7 +396,7 @@ public Option<MapStatus> stop(boolean success) {
}
}
} finally {
shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
Expand Down Expand Up @@ -102,6 +103,7 @@ public SortBasedShuffleWriter(
this.mapId = taskContext.partitionId();
this.dep = dep;
this.shuffleId = shuffleId;
this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
Expand Down Expand Up @@ -130,7 +132,7 @@ public SortBasedShuffleWriter(
taskContext,
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
Expand Down Expand Up @@ -280,7 +282,7 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
shuffleClient.pushData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
0,
Expand All @@ -298,12 +300,12 @@ private void close() throws IOException {
pusher.close();
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);

shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);

updateMapStatus();

long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}

Expand Down Expand Up @@ -339,7 +341,7 @@ public Option<MapStatus> stop(boolean success) {
} catch (IOException e) {
return Option.apply(null);
} finally {
shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public class SparkShuffleManager implements ShuffleManager {
private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker();

public SparkShuffleManager(SparkConf conf, boolean isDriver) {
SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
this.isDriver = isDriver;
this.celebornConf = SparkUtils.fromSparkConf(conf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class CelebornShuffleReader[K, C](
handle.extension)

private val exceptionRef = new AtomicReference[IOException]
private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context)

override def read(): Iterator[Product2[K, C]] = {

Expand Down Expand Up @@ -96,7 +97,7 @@ class CelebornShuffleReader[K, C](
val inputStream = shuffleClient.readPartition(
shuffleId,
partitionId,
context.attemptNumber(),
encodedAttemptId,
startMapIndex,
endMapIndex,
metricsCallback)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.internal.SQLConf
import org.junit
import org.junit.Assert
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

Expand Down Expand Up @@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
sc.stop()
}

@junit.Test
def testWrongSparkConfMaxAttemptLimit(): Unit = {
  // Base configuration that leaves both attempt limits at their defaults.
  val sparkConf = new SparkConf()
  sparkConf.setIfMissing("spark.master", "local")
  sparkConf.setIfMissing(
    "spark.shuffle.manager",
    "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
  sparkConf.set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
  sparkConf.set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
  sparkConf.set("spark.shuffle.service.enabled", "false")
  sparkConf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")

  // With the default limits, constructing the shuffle manager must succeed.
  new SparkShuffleManager(sparkConf, true)

  // A stage attempt limit of 32768 no longer fits the encoded attemptId and must be rejected.
  sparkConf.set("spark.stage.maxConsecutiveAttempts", "32768")
  sparkConf.set("spark.task.maxFailures", "10")

  // Capture whatever the constructor throws (if anything) before judging the outcome.
  val outcome: Option[Throwable] =
    try {
      new SparkShuffleManager(sparkConf, true)
      None
    } catch {
      case t: Throwable => Some(t)
    }

  outcome match {
    case Some(e: IllegalArgumentException) =>
      Assert.assertTrue(
        e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should be less than 32768"))
    case _ =>
      // Either nothing was thrown or an unexpected exception type was raised.
      Assert.fail()
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
Expand Down Expand Up @@ -112,6 +113,7 @@ public HashBasedShuffleWriter(
this.mapId = taskContext.partitionId();
this.dep = handle.dependency();
this.shuffleId = shuffleId;
this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
Expand Down Expand Up @@ -142,7 +144,7 @@ public HashBasedShuffleWriter(
new DataPusher(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
Expand Down Expand Up @@ -278,7 +280,7 @@ protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) thr
shuffleClient.pushData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
0,
Expand Down Expand Up @@ -342,7 +344,7 @@ protected void mergeData(int partitionId, byte[] buffer, int offset, int length)
shuffleClient.mergeData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
offset,
Expand All @@ -358,14 +360,14 @@ private void close() throws IOException, InterruptedException {
long pushMergedDataTime = System.nanoTime();
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
shuffleClient.prepareForMergeData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
closeWrite();
shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
updateRecordsWrittenMetrics();

long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
Expand Down Expand Up @@ -398,7 +400,7 @@ public Option<MapStatus> stop(boolean success) {
}
}
} finally {
shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}

Expand Down
Loading
Loading