Skip to content

Commit

Permalink
[CELEBORN-1518] Add support for Apache Spark barrier stages
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Adds support for barrier stages.
This involves two aspects:
a) If there is a task failure when executing a barrier stage, all shuffle output for the stage attempt are discarded and ignored.
b) If there is a reexecution of a barrier stage (for ex, due to child stage getting a fetch failure), all shuffle output for the previous stage attempt are discarded and ignored.

This is similar to handling of indeterminate stages when `throwsFetchFailure` is `true`.

Note that this is supported only when `spark.celeborn.client.spark.fetch.throwsFetchFailure` is `true`

### Why are the changes needed?

As detailed in CELEBORN-1518, Celeborn currently does not support barrier stages; which is an essential functionality in Apache Spark which is widely in use by Spark users.
Enhancing Celeborn will allow its use for a wider set of Spark users.

### Does this PR introduce _any_ user-facing change?

Adds ability for Celeborn to support Apache Spark Barrier stages.

### How was this patch tested?

Existing tests, and additional tests (thanks to jiang13021 in apache#2609 - [see here](https://github.com/apache/celeborn/pull/2609/files#diff-e17f15fcca26ddfc412f0af159c784d72417b0f22598e1b1ebfcacd6d4c3ad35))

Closes apache#2639 from mridulm/fix-barrier-stage-reexecution.

Lead-authored-by: Mridul Muralidharan <mridul@gmail.com>
Co-authored-by: Mridul Muralidharan <mridulatgmail.com>
Signed-off-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
  • Loading branch information
mridulm authored and waitinfuture committed Aug 12, 2024
1 parent a759efb commit 3234bef
Show file tree
Hide file tree
Showing 15 changed files with 447 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ public <K, V> ShuffleWriter<K, V> getWriter(
celebornConf,
h.userIdentifier(),
h.extension());
if (h.throwsFetchFailure()) {
SparkUtils.addFailureListenerIfBarrierTask(client, context, h);
}
int shuffleId = SparkUtils.celebornShuffleId(client, h, context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import scala.Tuple2;

import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
Expand Down Expand Up @@ -130,7 +131,11 @@ public static int celebornShuffleId(
Boolean isWriter) {
if (handle.throwsFetchFailure()) {
String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context);
return client.getShuffleId(handle.shuffleId(), appShuffleIdentifier, isWriter);
return client.getShuffleId(
handle.shuffleId(),
appShuffleIdentifier,
isWriter,
context instanceof BarrierTaskContext);
} else {
return handle.shuffleId();
}
Expand All @@ -157,4 +162,21 @@ public static <T> T instantiateClass(String className, SparkConf conf, Boolean i
}
}
}

// Adds a task failure listener which notifies lifecyclemanager when any
// task fails for a barrier stage
public static void addFailureListenerIfBarrierTask(
ShuffleClient shuffleClient, TaskContext taskContext, CelebornShuffleHandle<?, ?, ?> handle) {

if (!(taskContext instanceof BarrierTaskContext)) return;
int appShuffleId = handle.shuffleId();
String appShuffleIdentifier = SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext);

BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext;
barrierContext.addTaskFailureListener(
(context, error) -> {
// whatever is the reason for failure, we notify lifecycle manager about the failure
shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier);
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.celeborn

import org.apache.spark.shuffle.FetchFailedException

import org.apache.celeborn.common.util.ExceptionMaker

object ExceptionMakerHelper {

val FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id "

val SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER = new ExceptionMaker() {
override def makeFetchFailureException(
appShuffleId: Int,
shuffleId: Int,
partitionId: Int,
e: Exception): Exception = {
new FetchFailedException(
null,
appShuffleId,
-1,
partitionId,
FETCH_FAILURE_ERROR_MSG + appShuffleId + "/" + shuffleId,
e)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ public <K, V> ShuffleWriter<K, V> getWriter(
celebornConf,
h.userIdentifier(),
h.extension());
if (h.throwsFetchFailure()) {
SparkUtils.addFailureListenerIfBarrierTask(shuffleClient, context, h);
}
int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true);
shuffleIdTracker.track(h.shuffleId(), shuffleId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import scala.Tuple2;

import org.apache.spark.BarrierTaskContext;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
Expand Down Expand Up @@ -107,7 +108,11 @@ public static int celebornShuffleId(
Boolean isWriter) {
if (handle.throwsFetchFailure()) {
String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context);
return client.getShuffleId(handle.shuffleId(), appShuffleIdentifier, isWriter);
return client.getShuffleId(
handle.shuffleId(),
appShuffleIdentifier,
isWriter,
context instanceof BarrierTaskContext);
} else {
return handle.shuffleId();
}
Expand Down Expand Up @@ -274,4 +279,21 @@ public static void unregisterAllMapOutput(
throw new UnsupportedOperationException(
"unexpected! neither methods unregisterAllMapAndMergeOutput/unregisterAllMapOutput are found in MapOutputTrackerMaster");
}

// Adds a task failure listener which notifies lifecyclemanager when any
// task fails for a barrier stage
public static void addFailureListenerIfBarrierTask(
ShuffleClient shuffleClient, TaskContext taskContext, CelebornShuffleHandle<?, ?, ?> handle) {

if (!(taskContext instanceof BarrierTaskContext)) return;
int appShuffleId = handle.shuffleId();
String appShuffleIdentifier = SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext);

BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext;
barrierContext.addTaskFailureListener(
(context, error) -> {
// whatever is the reason for failure, we notify lifecycle manager about the failure
shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier);
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.celeborn

import org.apache.spark.shuffle.FetchFailedException

import org.apache.celeborn.common.util.ExceptionMaker

object ExceptionMakerHelper {

val FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id "

val SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER = new ExceptionMaker() {
override def makeFetchFailureException(
appShuffleId: Int,
shuffleId: Int,
partitionId: Int,
e: Exception): Exception = {
new FetchFailedException(
null,
appShuffleId,
-1,
-1,
partitionId,
FETCH_FAILURE_ERROR_MSG + appShuffleId + "/" + shuffleId,
e)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._

import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.shuffle.{FetchFailedException, ShuffleReader, ShuffleReadMetricsReporter}
Expand Down Expand Up @@ -97,23 +98,6 @@ class CelebornShuffleReader[K, C](
}
}

val exceptionMaker = new ExceptionMaker() {
override def makeFetchFailureException(
appShuffleId: Int,
shuffleId: Int,
partitionId: Int,
e: Exception): Exception = {
new FetchFailedException(
null,
appShuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + appShuffleId + "/" + shuffleId,
e)
}
}

val startTime = System.currentTimeMillis()
val fetchTimeoutMs = conf.clientFetchTimeoutMs
val localFetchEnabled = conf.enableReadLocalShuffleFile
Expand Down Expand Up @@ -212,7 +196,8 @@ class CelebornShuffleReader[K, C](
context.attemptNumber(),
startMapIndex,
endMapIndex,
if (throwsFetchFailure) exceptionMaker else null,
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
else null,
locations,
streamHandlers,
fileGroups.mapAttempts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ public abstract ConcurrentHashMap<Integer, PartitionLocation> getPartitionLocati

public abstract PushState getPushState(String mapKey);

public abstract int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean isWriter);
public abstract int getShuffleId(
int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage);

/**
* report shuffle data fetch failure to LifecycleManager for special handling, eg, shuffle status
Expand All @@ -277,5 +278,11 @@ public abstract ConcurrentHashMap<Integer, PartitionLocation> getPartitionLocati
*/
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId);

/**
* Report barrier task failure. When any barrier task fails, all (shuffle) output for that stage
* attempt is to be discarded, and spark will recompute the entire stage
*/
public abstract boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier);

public abstract TransportClientFactory getDataClientFactory();
}
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ public PushState getPushState(String mapKey) {
}

@Override
public int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean isWriter) {
public int getShuffleId(
int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage) {
return shuffleIdCache.computeIfAbsent(
appShuffleIdentifier,
(id) -> {
Expand All @@ -589,6 +590,7 @@ public int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean i
.setAppShuffleId(appShuffleId)
.setAppShuffleIdentifier(appShuffleIdentifier)
.setIsShuffleWriter(isWriter)
.setIsBarrierStage(isBarrierStage)
.build();
PbGetShuffleIdResponse pbGetShuffleIdResponse =
lifecycleManagerRef.askSync(
Expand All @@ -614,6 +616,20 @@ public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
return pbReportShuffleFetchFailureResponse.getSuccess();
}

public boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier) {
PbReportBarrierStageAttemptFailure pbReportBarrierStageAttemptFailure =
PbReportBarrierStageAttemptFailure.newBuilder()
.setAppShuffleId(appShuffleId)
.setAppShuffleIdentifier(appShuffleIdentifier)
.build();
PbReportBarrierStageAttemptFailureResponse pbReportBarrierStageAttemptFailureResponse =
lifecycleManagerRef.askSync(
pbReportBarrierStageAttemptFailure,
conf.clientRpcRegisterShuffleAskTimeout(),
ClassTag$.MODULE$.apply(PbReportBarrierStageAttemptFailureResponse.class));
return pbReportBarrierStageAttemptFailureResponse.getSuccess();
}

private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
int shuffleId,
int numMappers,
Expand Down
Loading

0 comments on commit 3234bef

Please sign in to comment.