
Commit 5e3615d

fenzhu authored and GitHub Enterprise committed
[CARMEL-7589][CARMEL-1376][CARMEL-4216] Limit the max numbers of tasks that one stage could generate (apache#216)
1 parent 4c027c5 commit 5e3615d

File tree: 2 files changed, +30 −2 lines changed


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 17 additions & 1 deletion
@@ -41,7 +41,7 @@ import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
-import org.apache.spark.internal.config.{JOB_GROUP_MAX_SHUFFLE_SIZE, RDD_CACHE_VISIBILITY_TRACKING_ENABLED, REMOVE_EXECUTOR_ON_FETCH_FAILURE, TASK_SUBMISSION_ASYNC}
+import org.apache.spark.internal.config.{JOB_GROUP_MAX_SHUFFLE_SIZE, RDD_CACHE_VISIBILITY_TRACKING_ENABLED, RDD_MAX_PARTITIONS, REMOVE_EXECUTOR_ON_FETCH_FAILURE, TASK_SUBMISSION_ASYNC}
 import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener}
 import org.apache.spark.network.shuffle.protocol.MergeStatuses
@@ -225,6 +225,8 @@ private[spark] class DAGScheduler(
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY)
 
+  private val MAX_PARTITIONS_IN_STAGE = sc.getConf.get(RDD_MAX_PARTITIONS)
+
   private val removeExecutorOnFetchFailure = sc.getConf.get(REMOVE_EXECUTOR_ON_FETCH_FAILURE)
 
   private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS)
@@ -674,6 +676,10 @@ private[spark] class DAGScheduler(
     checkBarrierStageWithNumSlots(rdd, resourceProfile)
     checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
     val numTasks = rdd.partitions.length
+    if (numTasks > MAX_PARTITIONS_IN_STAGE) {
+      throw new SparkException(s"RDD Partitions have reached the max limitation " +
+        s"$MAX_PARTITIONS_IN_STAGE, increase ${RDD_MAX_PARTITIONS.key} to work around.")
+    }
     val parents = getOrCreateParentStages(shuffleDeps, jobId)
     val id = nextStageId.getAndIncrement()
     val stage = new ShuffleMapStage(
@@ -809,6 +815,12 @@ private[spark] class DAGScheduler(
     checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size)
     val parents = getOrCreateParentStages(shuffleDeps, jobId)
     val id = nextStageId.getAndIncrement()
+    // Use `partitions.length` instead of `rdd.partitions.length` to
+    // skip SELECT * table LIMIT operation
+    if (partitions.length > MAX_PARTITIONS_IN_STAGE) {
+      throw new SparkException(s"RDD Partitions have reached the max limitation " +
+        s"$MAX_PARTITIONS_IN_STAGE, increase ${RDD_MAX_PARTITIONS.key} to work around.")
+    }
     val stage = new ResultStage(id, rdd, func, partitions, parents, jobId,
       callSite, resourceProfile.id, resultSpillContext)
     stageIdToStage(id) = stage
@@ -1324,6 +1336,10 @@ private[spark] class DAGScheduler(
     if (rdd.partitions.length == 0) {
       throw SparkCoreErrors.cannotRunSubmitMapStageOnZeroPartitionRDDError()
     }
+    if (rdd.partitions.length > MAX_PARTITIONS_IN_STAGE) {
+      throw new SparkException(s"RDD Partitions have reached the max limitation " +
+        s"$MAX_PARTITIONS_IN_STAGE, increase ${RDD_MAX_PARTITIONS.key} to work around.")
+    }
 
     // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute
     // `.partitions` on every RDD in the DAG to ensure that `getPartitions()`
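
Note: the change imports RDD_MAX_PARTITIONS from org.apache.spark.internal.config, but the entry's definition is not part of this diff. A minimal sketch of what such a ConfigEntry could look like in the internal config package object, with the key name, doc string, and default all being assumptions (only the constant name and its Int comparison are confirmed by the commit):

  // Hypothetical sketch only; the real entry lives elsewhere in this fork's
  // org.apache.spark.internal.config package object and is not shown here.
  private[spark] val RDD_MAX_PARTITIONS =
    ConfigBuilder("spark.rdd.maxPartitions") // assumed key name
      .doc("Maximum number of partitions (tasks) a single stage may contain " +
        "before the job is aborted with a SparkException.")
      .intConf
      .createWithDefault(Int.MaxValue) // assumed default: effectively unbounded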

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 13 additions & 1 deletion
@@ -41,7 +41,7 @@ import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
 import org.apache.spark.internal.config
-import org.apache.spark.internal.config.{SCHEDULER_ANALYTICS_TASK_SCHEDULER, TASK_SUBMISSION_ASYNC, Tests}
+import org.apache.spark.internal.config.{RDD_MAX_PARTITIONS, SCHEDULER_ANALYTICS_TASK_SCHEDULER, TASK_SUBMISSION_ASYNC, Tests}
 import org.apache.spark.network.shuffle.ExternalBlockStoreClient
 import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, ResourceProfileBuilder, TaskResourceProfile, TaskResourceRequests}
@@ -981,6 +981,18 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }
 
+  test("test job abort if the task reached the limitation") {
+    conf.set(RDD_MAX_PARTITIONS.key, "3")
+    val e = intercept[SparkException] {
+      val rdd = sc.parallelize(1 to 10, 4)
+      rdd.collect() === Array(1, 2)
+    }.getMessage
+    assert(e.contains("RDD Partitions have reached the max limitation"))
+    // Max partition restriction should skip LIMIT operation
+    val rdd = sc.parallelize(1 to 10, 4)
+    assert(rdd.take(2) === Array(1, 2))
+  }
+
   private val shuffleFileLossTests = Seq(
     ("executor process lost with shuffle service", ExecutorProcessLost("", None), true, false),
     ("worker lost with shuffle service", ExecutorProcessLost("", Some("hostA")), true, true),
