diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index bc7d1730966..83280c74803 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -191,6 +191,9 @@ public ShuffleWriter getWriter( celebornConf, h.userIdentifier(), h.extension()); + if (h.throwsFetchFailure()) { + SparkUtils.addFailureListenerIfBarrierTask(client, context, h); + } int shuffleId = SparkUtils.celebornShuffleId(client, h, context, true); shuffleIdTracker.track(h.shuffleId(), shuffleId); diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 0f03b5688f3..4f38c98152b 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -24,6 +24,7 @@ import scala.Tuple2; +import org.apache.spark.BarrierTaskContext; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.TaskContext; @@ -130,7 +131,11 @@ public static int celebornShuffleId( Boolean isWriter) { if (handle.throwsFetchFailure()) { String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context); - return client.getShuffleId(handle.shuffleId(), appShuffleIdentifier, isWriter); + return client.getShuffleId( + handle.shuffleId(), + appShuffleIdentifier, + isWriter, + context instanceof BarrierTaskContext); } else { return handle.shuffleId(); } @@ -157,4 +162,21 @@ public static T instantiateClass(String className, SparkConf conf, Boolean i } } } + + // Adds a task failure listener which notifies lifecyclemanager when any + // 
task fails for a barrier stage + public static void addFailureListenerIfBarrierTask( + ShuffleClient shuffleClient, TaskContext taskContext, CelebornShuffleHandle handle) { + + if (!(taskContext instanceof BarrierTaskContext)) return; + int appShuffleId = handle.shuffleId(); + String appShuffleIdentifier = SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext); + + BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext; + barrierContext.addTaskFailureListener( + (context, error) -> { + // whatever is the reason for failure, we notify lifecycle manager about the failure + shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier); + }); + } } diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala new file mode 100644 index 00000000000..6456f3f44f9 --- /dev/null +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.celeborn + +import org.apache.spark.shuffle.FetchFailedException + +import org.apache.celeborn.common.util.ExceptionMaker + +object ExceptionMakerHelper { + + val FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id " + + val SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER = new ExceptionMaker() { + override def makeFetchFailureException( + appShuffleId: Int, + shuffleId: Int, + partitionId: Int, + e: Exception): Exception = { + new FetchFailedException( + null, + appShuffleId, + -1, + partitionId, + FETCH_FAILURE_ERROR_MSG + appShuffleId + "/" + shuffleId, + e) + } + } +} diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 2d71ece85bb..4c84e9d5360 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -243,6 +243,9 @@ public ShuffleWriter getWriter( celebornConf, h.userIdentifier(), h.extension()); + if (h.throwsFetchFailure()) { + SparkUtils.addFailureListenerIfBarrierTask(shuffleClient, context, h); + } int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true); shuffleIdTracker.track(h.shuffleId(), shuffleId); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 3d97c1b98b4..47317474e8d 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -21,6 +21,7 @@ import scala.Tuple2; +import org.apache.spark.BarrierTaskContext; import org.apache.spark.MapOutputTrackerMaster; import org.apache.spark.SparkConf; import 
org.apache.spark.SparkContext; @@ -107,7 +108,11 @@ public static int celebornShuffleId( Boolean isWriter) { if (handle.throwsFetchFailure()) { String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context); - return client.getShuffleId(handle.shuffleId(), appShuffleIdentifier, isWriter); + return client.getShuffleId( + handle.shuffleId(), + appShuffleIdentifier, + isWriter, + context instanceof BarrierTaskContext); } else { return handle.shuffleId(); } @@ -274,4 +279,21 @@ public static void unregisterAllMapOutput( throw new UnsupportedOperationException( "unexpected! neither methods unregisterAllMapAndMergeOutput/unregisterAllMapOutput are found in MapOutputTrackerMaster"); } + + // Adds a task failure listener which notifies lifecyclemanager when any + // task fails for a barrier stage + public static void addFailureListenerIfBarrierTask( + ShuffleClient shuffleClient, TaskContext taskContext, CelebornShuffleHandle handle) { + + if (!(taskContext instanceof BarrierTaskContext)) return; + int appShuffleId = handle.shuffleId(); + String appShuffleIdentifier = SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext); + + BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext; + barrierContext.addTaskFailureListener( + (context, error) -> { + // whatever is the reason for failure, we notify lifecycle manager about the failure + shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier); + }); + } } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala new file mode 100644 index 00000000000..72d9019f8ee --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.celeborn + +import org.apache.spark.shuffle.FetchFailedException + +import org.apache.celeborn.common.util.ExceptionMaker + +object ExceptionMakerHelper { + + val FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id " + + val SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER = new ExceptionMaker() { + override def makeFetchFailureException( + appShuffleId: Int, + shuffleId: Int, + partitionId: Int, + e: Exception): Exception = { + new FetchFailedException( + null, + appShuffleId, + -1, + -1, + partitionId, + FETCH_FAILURE_ERROR_MSG + appShuffleId + "/" + shuffleId, + e) + } + } +} diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 66126c70e78..74a4daeb053 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} +import 
org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.{FetchFailedException, ShuffleReader, ShuffleReadMetricsReporter} @@ -97,23 +98,6 @@ class CelebornShuffleReader[K, C]( } } - val exceptionMaker = new ExceptionMaker() { - override def makeFetchFailureException( - appShuffleId: Int, - shuffleId: Int, - partitionId: Int, - e: Exception): Exception = { - new FetchFailedException( - null, - appShuffleId, - -1, - -1, - partitionId, - SparkUtils.FETCH_FAILURE_ERROR_MSG + appShuffleId + "/" + shuffleId, - e) - } - } - val startTime = System.currentTimeMillis() val fetchTimeoutMs = conf.clientFetchTimeoutMs val localFetchEnabled = conf.enableReadLocalShuffleFile @@ -212,7 +196,8 @@ class CelebornShuffleReader[K, C]( context.attemptNumber(), startMapIndex, endMapIndex, - if (throwsFetchFailure) exceptionMaker else null, + if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER + else null, locations, streamHandlers, fileGroups.mapAttempts, diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 0738368ff66..07ce7b10e88 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -268,7 +268,8 @@ public abstract ConcurrentHashMap getPartitionLocati public abstract PushState getPushState(String mapKey); - public abstract int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean isWriter); + public abstract int getShuffleId( + int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage); /** * report shuffle data fetch failure to LifecycleManager for special handling, eg, shuffle status @@ -277,5 +278,11 @@ public abstract ConcurrentHashMap getPartitionLocati */ public abstract boolean 
reportShuffleFetchFailure(int appShuffleId, int shuffleId); + /** + * Report barrier task failure. When any barrier task fails, all (shuffle) output for that stage + * attempt is to be discarded, and spark will recompute the entire stage + */ + public abstract boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier); + public abstract TransportClientFactory getDataClientFactory(); } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 1f22acb9682..ad8fe3063f3 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -580,7 +580,8 @@ public PushState getPushState(String mapKey) { } @Override - public int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean isWriter) { + public int getShuffleId( + int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage) { return shuffleIdCache.computeIfAbsent( appShuffleIdentifier, (id) -> { @@ -589,6 +590,7 @@ public int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean i .setAppShuffleId(appShuffleId) .setAppShuffleIdentifier(appShuffleIdentifier) .setIsShuffleWriter(isWriter) + .setIsBarrierStage(isBarrierStage) .build(); PbGetShuffleIdResponse pbGetShuffleIdResponse = lifecycleManagerRef.askSync( @@ -614,6 +616,21 @@ public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { return pbReportShuffleFetchFailureResponse.getSuccess(); } + @Override + public boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier) { + PbReportBarrierStageAttemptFailure pbReportBarrierStageAttemptFailure = + PbReportBarrierStageAttemptFailure.newBuilder() + .setAppShuffleId(appShuffleId) + .setAppShuffleIdentifier(appShuffleIdentifier) + .build(); + PbReportBarrierStageAttemptFailureResponse
pbReportBarrierStageAttemptFailureResponse = + lifecycleManagerRef.askSync( + pbReportBarrierStageAttemptFailure, + conf.clientRpcRegisterShuffleAskTimeout(), + ClassTag$.MODULE$.apply(PbReportBarrierStageAttemptFailureResponse.class)); + return pbReportBarrierStageAttemptFailureResponse.getSuccess(); + } + private ConcurrentHashMap registerShuffleInternal( int shuffleId, int numMappers, diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index d55aee47d74..82d92d85316 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -408,8 +408,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val appShuffleId = pb.getAppShuffleId val appShuffleIdentifier = pb.getAppShuffleIdentifier val isWriter = pb.getIsShuffleWriter - logDebug(s"Received GetShuffleId request, appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter.") - handleGetShuffleIdForApp(context, appShuffleId, appShuffleIdentifier, isWriter) + val isBarrierStage = pb.getIsBarrierStage + logDebug(s"Received GetShuffleId request, appShuffleId $appShuffleId " + + s"appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter isBarrier $isBarrierStage.") + handleGetShuffleIdForApp( + context, + appShuffleId, + appShuffleIdentifier, + isWriter, + isBarrierStage) case pb: PbReportShuffleFetchFailure => val appShuffleId = pb.getAppShuffleId @@ -417,6 +424,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId") handleReportShuffleFetchFailure(context, appShuffleId, shuffleId) + case pb: PbReportBarrierStageAttemptFailure => + val appShuffleId = pb.getAppShuffleId + val appShuffleIdentifier = 
pb.getAppShuffleIdentifier + logDebug(s"Received ReportBarrierStageAttemptFailure request, appShuffleId $appShuffleId, " + + s"appShuffleIdentifier $appShuffleIdentifier") + handleReportBarrierStageAttemptFailure(context, appShuffleId, appShuffleIdentifier) + case pb: PbApplicationMetaRequest => logDebug(s"Received request for meta info ${pb.getAppId}.") if (applicationMeta == null) { @@ -790,7 +804,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends context: RpcCallContext, appShuffleId: Int, appShuffleIdentifier: String, - isWriter: Boolean): Unit = { + isWriter: Boolean, + isBarrierStage: Boolean): Unit = { val shuffleIds = if (isWriter) { shuffleIdMapping.computeIfAbsent( @@ -831,7 +846,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends case None => Option(appShuffleDeterminateMap.get(appShuffleId)).map { determinate => val candidateShuffle = - if (determinate) + // For barrier stages, all tasks are re-executed when it is re-run : similar to indeterminate stage. 
+ // So if a barrier stage is getting reexecuted, previous stage/attempt needs to + // be cleaned up as it is entirely unusable + if (determinate && !isBarrierStage) shuffleIds.values.toSeq.reverse.find(e => e._2 == true) else None @@ -842,6 +860,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logInfo(s"reuse existing shuffleId $id for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") id } else { + if (isBarrierStage) { + // unregister previous shuffle(s) which are still valid + val mapUpdates = shuffleIds.filter(_._2._2).map { kv => + unregisterShuffle(kv._2._1) + kv._1 -> (kv._2._1, false) + } + shuffleIds ++= mapUpdates + } val newShuffleId = shuffleIdGenerator.getAndIncrement() logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") shuffleIds.put(appShuffleIdentifier, (newShuffleId, true)) @@ -855,7 +881,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends s"unexpected! unknown appShuffleId $appShuffleId when checking shuffle deterministic level")) } } else { - shuffleIds.values.map(v => v._1).toSeq.reverse.find(isAllMaptaskEnd) match { + shuffleIds.values.filter(v => v._2).map(v => v._1).toSeq.reverse.find( + isAllMaptaskEnd) match { case Some(shuffleId) => val pbGetShuffleIdResponse = { logDebug( @@ -885,20 +912,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds.find(e => e._2._1 == shuffleId) match { case Some((appShuffleIdentifier, (shuffleId, true))) => logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId") - appShuffleTrackerCallback match { - case Some(callback) => - try { - callback.accept(appShuffleId) - } catch { - case t: Throwable => - logError(t.toString) - ret = false - } - shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) - case None => - throw new UnsupportedOperationException( - "unexpected!
appShuffleTrackerCallback is not registered") - } + ret = invokeAppShuffleTrackerCallback(appShuffleId) + shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) case Some((appShuffleIdentifier, (shuffleId, false))) => logInfo( s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId, " + @@ -913,6 +928,54 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends context.reply(pbReportShuffleFetchFailureResponse) } + private def handleReportBarrierStageAttemptFailure( + context: RpcCallContext, + appShuffleId: Int, + appShuffleIdentifier: String): Unit = { + + val shuffleIds = shuffleIdMapping.get(appShuffleId) + if (shuffleIds == null) { + throw new UnsupportedOperationException(s"unexpected! unknown appShuffleId $appShuffleId") + } + var ret = true + shuffleIds.synchronized { + shuffleIds.get(appShuffleIdentifier) match { + case Some((shuffleId, true)) => + ret = invokeAppShuffleTrackerCallback(appShuffleId) + unregisterShuffle(shuffleId) + shuffleIds.put(appShuffleIdentifier, (shuffleId, false)) + case Some((shuffleId, false)) => + // older entry, already handled + logInfo( + s"Ignoring failure from barrier task for appShuffleIdentifier $appShuffleIdentifier " + + s"shuffleId $shuffleId for appShuffleId $appShuffleId, already handled") + case None => + throw new UnsupportedOperationException( + s"unexpected! 
unknown appShuffleId $appShuffleId for appShuffleIdentifier = $appShuffleIdentifier") + } + } + val pbReportBarrierStageAttemptFailureResponse = + PbReportBarrierStageAttemptFailureResponse.newBuilder().setSuccess(ret).build() + context.reply(pbReportBarrierStageAttemptFailureResponse) + } + + private def invokeAppShuffleTrackerCallback(appShuffleId: Int): Boolean = { + appShuffleTrackerCallback match { + case Some(callback) => + try { + callback.accept(appShuffleId) + true + } catch { + case t: Throwable => + logError(t.toString) + false + } + case None => + throw new UnsupportedOperationException( + "unexpected! appShuffleTrackerCallback is not registered") + } + } + private def handleStageEnd(shuffleId: Int): Unit = { // check whether shuffle has registered if (!registeredShuffle.contains(shuffleId)) { diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index cae634fa748..49b6b5c546a 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -173,7 +173,8 @@ public PushState getPushState(String mapKey) { } @Override - public int getShuffleId(int appShuffleId, String appShuffleIdentifier, boolean isWriter) { + public int getShuffleId( + int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage) { return appShuffleId; } @@ -182,6 +183,11 @@ public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) { return true; } + @Override + public boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier) { + return true; + } + @Override public TransportClientFactory getDataClientFactory() { return null; diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java index
1d684b2178a..aa5d7ad51e8 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java @@ -87,6 +87,10 @@ public T getParsedPayload() throws InvalidProtoco return (T) PbReportShuffleFetchFailure.parseFrom(payload); case REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE_VALUE: return (T) PbReportShuffleFetchFailureResponse.parseFrom(payload); + case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_VALUE: + return (T) PbReportBarrierStageAttemptFailure.parseFrom(payload); + case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE_VALUE: + return (T) PbReportBarrierStageAttemptFailureResponse.parseFrom(payload); case SASL_REQUEST_VALUE: return (T) PbSaslRequest.parseFrom(payload); case AUTHENTICATION_INITIATION_REQUEST_VALUE: diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 2b8432c44c5..ae57d4cda74 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -103,6 +103,8 @@ enum MessageType { BATCH_OPEN_STREAM = 80; BATCH_OPEN_STREAM_RESPONSE = 81; REPORT_WORKER_DECOMMISSION = 82; + REPORT_BARRIER_STAGE_ATTEMPT_FAILURE = 83; + REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE = 84; } enum StreamType { @@ -368,6 +370,7 @@ message PbGetShuffleId { int32 appShuffleId = 1; string appShuffleIdentifier = 2; bool isShuffleWriter = 3; + bool isBarrierStage = 4; } message PbGetShuffleIdResponse { @@ -383,6 +386,15 @@ message PbReportShuffleFetchFailureResponse { bool success = 1; } +message PbReportBarrierStageAttemptFailure { + int32 appShuffleId = 1; + string appShuffleIdentifier = 2; +} + +message PbReportBarrierStageAttemptFailureResponse { + bool success = 1; +} + message PbUnregisterShuffle { string appId = 1; int32 shuffleId = 2; diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index c94744bb434..b727e1b180b 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -530,6 +530,14 @@ object ControlMessages extends Logging { case pb: PbReportShuffleFetchFailureResponse => new TransportMessage(MessageType.REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE, pb.toByteArray) + case pb: PbReportBarrierStageAttemptFailure => + new TransportMessage(MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE, pb.toByteArray) + + case pb: PbReportBarrierStageAttemptFailureResponse => + new TransportMessage( + MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE, + pb.toByteArray) + case HeartbeatFromWorker( host, rpcPort, @@ -1294,6 +1302,12 @@ object ControlMessages extends Logging { case APPLICATION_META_REQUEST_VALUE => PbApplicationMetaRequest.parseFrom(message.getPayload) + + case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_VALUE => + PbReportBarrierStageAttemptFailure.parseFrom(message.getPayload) + + case REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE_VALUE => + PbReportBarrierStageAttemptFailureResponse.parseFrom(message.getPayload) } } } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala index 233d4c50c3f..f3cd382118c 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala @@ -17,10 +17,12 @@ package org.apache.celeborn.tests.spark -import java.io.File +import java.io.{File, IOException} import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.{SparkConf, SparkContextHelper, TaskContext} +import 
org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext} +import org.apache.spark.celeborn.ExceptionMakerHelper +import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} import org.apache.spark.sql.SparkSession @@ -102,7 +104,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) .config("spark.sql.shuffle.partitions", 2) .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.shuffle.enabled", "true") .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") .config( "spark.shuffle.manager", @@ -144,7 +145,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) .config("spark.sql.shuffle.partitions", 2) .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.shuffle.enabled", "true") .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "false") .getOrCreate() @@ -177,7 +177,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) .config("spark.sql.shuffle.partitions", 2) .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.shuffle.enabled", "true") .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") .config( "spark.shuffle.manager", @@ -209,7 +208,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) .config("spark.sql.shuffle.partitions", 2) .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.shuffle.enabled", "true") .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") .config( 
"spark.shuffle.manager", @@ -250,7 +248,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) .config("spark.sql.shuffle.partitions", 2) .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.shuffle.enabled", "true") .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") .config( "spark.shuffle.manager", @@ -282,7 +279,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) .config("spark.sql.shuffle.partitions", 2) .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) - .config("spark.celeborn.shuffle.enabled", "true") .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") .getOrCreate() @@ -297,4 +293,162 @@ class CelebornFetchFailureSuite extends AnyFunSuite sparkSession.stop() } } + + test(s"celeborn spark integration test - resubmit an unordered barrier stage with throwsFetchFailure enabled") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config("spark.celeborn.client.push.buffer.max.size", 0) + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + val rdd1 = sc + .parallelize(0 until 10000, 2) + .map(v => (v, v)) + .groupByKey() + .barrier() + .mapPartitions { + it => + val context = BarrierTaskContext.get() + if (context.stageAttemptNumber() == 0 && context.partitionId() == 0) { + Thread.sleep(3000) + throw new RuntimeException("failed") + } + if (context.stageAttemptNumber() > 0) { + it.toBuffer.reverseIterator + } 
else { + it + } + } + val rdd2 = rdd1.map(v => (v._2, v._1)).groupByKey() + val result = rdd2.collect() + result.foreach { + elem => + assert(elem._1.size == elem._2.size) + } + } finally { + sparkSession.stop() + } + } + + test(s"celeborn spark integration test - fetch failure in child of an unordered barrier stage with throwsFetchFailure enabled") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config("spark.celeborn.client.push.buffer.max.size", 0) + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + val inputGroupedRdd = sc + .parallelize(0 until 10000, 2) + .map(v => (v, v)) + .groupByKey() + val rdd1 = inputGroupedRdd + .barrier() + .mapPartitions(it => it) + val groupedRdd = rdd1.map(v => (v._2, v._1)).groupByKey() + val appShuffleId = findAppShuffleId(groupedRdd) + assert(findAppShuffleId(groupedRdd) != findAppShuffleId(inputGroupedRdd)) + val rdd2 = groupedRdd.mapPartitions { iter => + val context = TaskContext.get() + if (context.stageAttemptNumber() == 0 && context.partitionId() == 0) { + throw ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER.makeFetchFailureException( + appShuffleId, + -1, + context.partitionId(), + new IOException("forced")) + } + iter + } + val result = rdd2.collect() + result.foreach { + elem => + assert(elem._1.size == elem._2.size) + } + } finally { + sparkSession.stop() + } + } + + test(s"celeborn spark integration test - resubmit a failed barrier stage across jobs") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]") + val sparkSession = 
SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true") + .config("spark.celeborn.client.push.buffer.max.size", 0) + .config("spark.stage.maxConsecutiveAttempts", "1") + .config("spark.stage.maxAttempts", "1") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + // trigger failure + CelebornFetchFailureSuite.triggerFailure.set(true) + + try { + val sc = sparkSession.sparkContext + val rdd1 = sc + .parallelize(0 until 10000, 2) + .map(v => (v, v)) + .groupByKey() + .barrier() + .mapPartitions { + it => + val context = BarrierTaskContext.get() + if (context.partitionId() == 0 && CelebornFetchFailureSuite.triggerFailure.get()) { + Thread.sleep(3000) + throw new RuntimeException("failed") + } + if (CelebornFetchFailureSuite.triggerFailure.get()) { + it + } else { + it.toBuffer.reverseIterator + } + } + val rdd2 = rdd1.map(v => (v._2, v._1)).groupByKey() + assertThrows[SparkException] { + rdd2.collect() + } + + // Now, allow it to succeed + CelebornFetchFailureSuite.triggerFailure.set(false) + val result = rdd2.collect() + result.foreach { + elem => + assert(elem._1.size == elem._2.size) + } + } finally { + sparkSession.stop() + } + } + + private def findAppShuffleId(rdd: RDD[_]): Int = { + val deps = rdd.dependencies + if (deps.size != 1 || !deps.head.isInstanceOf[ShuffleDependency[_, _, _]]) { + throw new IllegalArgumentException("Expected an RDD with shuffle dependency: " + rdd) + } + + deps.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId + } +} + +object CelebornFetchFailureSuite { + private val triggerFailure = new AtomicBoolean(false) }