@@ -41,7 +41,7 @@ import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
-import org.apache.spark.internal.config.{JOB_GROUP_MAX_SHUFFLE_SIZE, RDD_CACHE_VISIBILITY_TRACKING_ENABLED, REMOVE_EXECUTOR_ON_FETCH_FAILURE, TASK_SUBMISSION_ASYNC}
+import org.apache.spark.internal.config.{JOB_GROUP_MAX_SHUFFLE_SIZE, RDD_CACHE_VISIBILITY_TRACKING_ENABLED, RDD_MAX_PARTITIONS, REMOVE_EXECUTOR_ON_FETCH_FAILURE, TASK_SUBMISSION_ASYNC}
 import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener}
 import org.apache.spark.network.shuffle.protocol.MergeStatuses
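The new import pulls in `RDD_MAX_PARTITIONS`, whose definition is not shown in this hunk. A minimal sketch of what such an entry could look like in `org.apache.spark.internal.config`, using Spark's `ConfigBuilder` API; the key name `spark.rdd.maxPartitions`, the default, and the doc text below are assumptions, not taken from the patch:

```scala
// Hypothetical declaration for org.apache.spark.internal.config (key name,
// default, and wording are assumptions; only the identifier comes from the diff).
private[spark] val RDD_MAX_PARTITIONS =
  ConfigBuilder("spark.rdd.maxPartitions")
    .doc("Maximum number of partitions allowed in a single stage. Stage creation " +
      "fails fast when the task count exceeds this limit.")
    .intConf
    .checkValue(_ > 0, "The maximum number of partitions must be positive.")
    .createWithDefault(Int.MaxValue)
```

With `Int.MaxValue` as the default, the guard is effectively disabled unless an operator opts in.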
@@ -225,6 +225,8 @@ private[spark] class DAGScheduler(
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY)
 
+  private val MAX_PARTITIONS_IN_STAGE = sc.getConf.get(RDD_MAX_PARTITIONS)
+
   private val removeExecutorOnFetchFailure = sc.getConf.get(REMOVE_EXECUTOR_ON_FETCH_FAILURE)
 
   private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS)
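The scheduler reads the limit once at construction time, so it applies uniformly to every stage of the application. Assuming the hypothetical key `spark.rdd.maxPartitions` from the sketch above, an application would opt in like this:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Cap every stage at 100,000 tasks (key name is an assumption, see above).
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("max-partitions-demo")
  .set("spark.rdd.maxPartitions", "100000")
val sc = new SparkContext(conf)
```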
@@ -674,6 +676,10 @@ private[spark] class DAGScheduler(
     checkBarrierStageWithNumSlots(rdd, resourceProfile)
     checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
     val numTasks = rdd.partitions.length
+    if (numTasks > MAX_PARTITIONS_IN_STAGE) {
+      throw new SparkException(s"The number of RDD partitions ($numTasks) exceeds the maximum " +
+        s"allowed ($MAX_PARTITIONS_IN_STAGE); increase ${RDD_MAX_PARTITIONS.key} to work around this.")
+    }
     val parents = getOrCreateParentStages(shuffleDeps, jobId)
     val id = nextStageId.getAndIncrement()
     val stage = new ShuffleMapStage(
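To see the shuffle-stage guard fire, set the limit deliberately low and run a shuffle whose map side has more partitions than the limit. A sketch, again assuming the key `spark.rdd.maxPartitions`; it relies on stage-creation failures propagating to the calling action, which is the scheduler's usual behavior:

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

val sc = new SparkContext(
  new SparkConf()
    .setMaster("local[2]")
    .setAppName("guard-demo")
    .set("spark.rdd.maxPartitions", "10")) // assumed key, set deliberately low
try {
  // 100 map partitions > 10, so createShuffleMapStage should reject the stage.
  sc.parallelize(1 to 1000, numSlices = 100)
    .map(i => (i % 7, i))
    .groupByKey()
    .count()
} catch {
  case e: SparkException => println(s"Rejected as expected: ${e.getMessage}")
} finally {
  sc.stop()
}
```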
@@ -809,6 +815,12 @@ private[spark] class DAGScheduler(
     checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size)
     val parents = getOrCreateParentStages(shuffleDeps, jobId)
     val id = nextStageId.getAndIncrement()
+    // Check `partitions.length` instead of `rdd.partitions.length` so that jobs computing
+    // only a subset of partitions (e.g. SELECT * FROM table LIMIT n) are not rejected.
+    if (partitions.length > MAX_PARTITIONS_IN_STAGE) {
+      throw new SparkException(s"The number of RDD partitions (${partitions.length}) exceeds the " +
+        s"maximum allowed ($MAX_PARTITIONS_IN_STAGE); increase ${RDD_MAX_PARTITIONS.key} to work around this.")
+    }
     val stage = new ResultStage(id, rdd, func, partitions, parents, jobId,
       callSite, resourceProfile.id, resultSpillContext)
     stageIdToStage(id) = stage
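Checking `partitions.length` matters because a result stage may compute only a subset of the RDD's partitions; `take`/LIMIT-style jobs start with a single partition, and rejecting them based on the full `rdd.partitions.length` would break them. A sketch of that case (assumed key as before):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(
  new SparkConf()
    .setMaster("local[2]")
    .setAppName("subset-demo")
    .set("spark.rdd.maxPartitions", "10")) // assumed key

val big = sc.parallelize(1 to 1000000, numSlices = 1000) // 1000 partitions > 10

// runJob over a single partition: createResultStage sees partitions.length == 1,
// so the guard passes even though big.partitions.length is 1000.
val first = sc.runJob(big, (it: Iterator[Int]) => it.take(5).toArray, Seq(0))
println(first.head.mkString(","))
sc.stop()
```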
@@ -1324,6 +1336,10 @@ private[spark] class DAGScheduler(
     if (rdd.partitions.length == 0) {
       throw SparkCoreErrors.cannotRunSubmitMapStageOnZeroPartitionRDDError()
     }
+    if (rdd.partitions.length > MAX_PARTITIONS_IN_STAGE) {
+      throw new SparkException(s"The number of RDD partitions (${rdd.partitions.length}) exceeds " +
+        s"the maximum allowed ($MAX_PARTITIONS_IN_STAGE); increase ${RDD_MAX_PARTITIONS.key} to work around this.")
+    }
 
     // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute
     // `.partitions` on every RDD in the DAG to ensure that `getPartitions()`
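The same guard now appears at three call sites (`createShuffleMapStage`, `createResultStage`, and the `submitMapStage` path). A possible follow-up, not part of this diff, would be to fold the check and its message into a single private helper on `DAGScheduler`; a minimal sketch:

```scala
// Hypothetical helper (not in this patch): one home for the limit check and message.
private def checkPartitionLimit(numPartitions: Int): Unit = {
  if (numPartitions > MAX_PARTITIONS_IN_STAGE) {
    throw new SparkException(s"The number of RDD partitions ($numPartitions) exceeds the maximum " +
      s"allowed ($MAX_PARTITIONS_IN_STAGE); increase ${RDD_MAX_PARTITIONS.key} to work around this.")
  }
}
```

Each call site would then reduce to `checkPartitionLimit(numTasks)`, `checkPartitionLimit(partitions.length)`, or `checkPartitionLimit(rdd.partitions.length)`.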