diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ef974dc176e51..0e6b69ae274e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -509,7 +509,7 @@ object SQLConf {
         "'spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes'")
       .version("3.0.0")
       .intConf
-      .checkValue(_ > 0, "The skew factor must be positive.")
+      .checkValue(_ >= 0, "The skew factor cannot be negative.")
       .createWithDefault(5)
 
   val SKEW_JOIN_SKEWED_PARTITION_THRESHOLD =
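Why relax the check: with the factor at 0, skew detection degenerates to the size threshold alone, which the updated test below exploits (together with a 1-byte threshold) to force a skew join deterministically. A minimal sketch of the detection rule, paraphrased from OptimizeSkewedJoin — the helper name and parameters here are illustrative, not the exact internals:

    // A partition counts as skewed only if it beats BOTH the relative factor
    // and the absolute byte threshold (paraphrased; see OptimizeSkewedJoin).
    def isSkewed(size: Long, medianSize: Long, factor: Int, thresholdBytes: Long): Boolean =
      size > medianSize * factor && size > thresholdBytes

    isSkewed(size = 8, medianSize = 8, factor = 0, thresholdBytes = 1)  // true: threshold decides
    isSkewed(size = 8, medianSize = 8, factor = 5, thresholdBytes = 1)  // false: 8 <= 8 * 5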
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index f003830f3a45d..45ba2202d83d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecuti
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
@@ -1319,8 +1319,15 @@ class AdaptiveQueryExecSuite
   }
 
   test("SPARK-33551: Do not use custom shuffle reader for repartition") {
+    def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
+      find(plan) {
+        case s: ShuffleExchangeLike =>
+          s.shuffleOrigin == REPARTITION || s.shuffleOrigin == REPARTITION_WITH_NUM
+        case _ => false
+      }.isDefined
+    }
+
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
       SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
       val df = sql(
         """
@@ -1331,45 +1338,97 @@ class AdaptiveQueryExecSuite
           |ON value = b
         """.stripMargin)
 
-      // Repartition with no partition num specified.
-      val dfRepartition = df.repartition('b)
-      dfRepartition.collect()
-      val plan = dfRepartition.queryExecution.executedPlan
-      val bhj = findTopLevelBroadcastHashJoin(plan)
-      assert(bhj.length == 1)
-      checkNumLocalShuffleReaders(plan, 1)
-      // Probe side is coalesced.
-      val customReader = bhj.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
-      assert(customReader.isDefined)
-      assert(customReader.get.asInstanceOf[CustomShuffleReaderExec].hasCoalescedPartition)
-
-      // Repartition with partition default num specified.
-      val dfRepartitionWithNum = df.repartition(5, 'b)
-      dfRepartitionWithNum.collect()
-      val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
-      val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
-      assert(bhjWithNum.length == 1)
-      checkNumLocalShuffleReaders(planWithNum, 1)
-      // Probe side is not coalesced.
-      assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
-
-      // Repartition with partition non-default num specified.
-      val dfRepartitionWithNum2 = df.repartition(3, 'b)
-      dfRepartitionWithNum2.collect()
-      val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
-      val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
-      assert(bhjWithNum2.length == 1)
-      // The top shuffle from repartition is not optimized out, and this is the only shuffle that
-      // does not have local shuffle reader.
-      val repartition = find(planWithNum2) {
-        case s: ShuffleExchangeLike => s.shuffleOrigin == REPARTITION_WITH_NUM
-        case _ => false
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+        // Repartition with no partition num specified.
+        val dfRepartition = df.repartition('b)
+        dfRepartition.collect()
+        val plan = dfRepartition.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(plan))
+        val bhj = findTopLevelBroadcastHashJoin(plan)
+        assert(bhj.length == 1)
+        checkNumLocalShuffleReaders(plan, 1)
+        // Probe side is coalesced.
+        val customReader = bhj.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
+        assert(customReader.isDefined)
+        assert(customReader.get.asInstanceOf[CustomShuffleReaderExec].hasCoalescedPartition)
+
+        // Repartition with default partition num specified.
+        val dfRepartitionWithNum = df.repartition(5, 'b)
+        dfRepartitionWithNum.collect()
+        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(planWithNum))
+        val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
+        assert(bhjWithNum.length == 1)
+        checkNumLocalShuffleReaders(planWithNum, 1)
+        // Probe side is not coalesced.
+        assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
+
+        // Repartition with non-default partition num specified.
+        val dfRepartitionWithNum2 = df.repartition(3, 'b)
+        dfRepartitionWithNum2.collect()
+        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
+        // The top shuffle from repartition is not optimized out, and this is the only shuffle
+        // that does not have a local shuffle reader.
+        assert(hasRepartitionShuffle(planWithNum2))
+        val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
+        assert(bhjWithNum2.length == 1)
+        checkNumLocalShuffleReaders(planWithNum2, 1)
+        val customReader2 = bhjWithNum2.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
+        assert(customReader2.isDefined)
+        assert(customReader2.get.asInstanceOf[CustomShuffleReaderExec].isLocalReader)
+      }
+
+      // Force skew join
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.SKEW_JOIN_ENABLED.key -> "true",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1",
+        SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
+        // Repartition with no partition num specified.
+        val dfRepartition = df.repartition('b)
+        dfRepartition.collect()
+        val plan = dfRepartition.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(plan))
+        val smj = findTopLevelSortMergeJoin(plan)
+        assert(smj.length == 1)
+        // No skew join due to the repartition.
+        assert(!smj.head.isSkewJoin)
+        // Both sides are coalesced.
+        val customReaders = collect(smj.head) {
+          case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+        }
+        assert(customReaders.length == 2)
+
+        // Repartition with default partition num specified.
+        val dfRepartitionWithNum = df.repartition(5, 'b)
+        dfRepartitionWithNum.collect()
+        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(planWithNum))
+        val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
+        assert(smjWithNum.length == 1)
+        // No skew join due to the repartition.
+        assert(!smjWithNum.head.isSkewJoin)
+        // No coalesce due to the num in repartition.
+        val customReadersWithNum = collect(smjWithNum.head) {
+          case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+        }
+        assert(customReadersWithNum.isEmpty)
+
+        // Repartition with non-default partition num specified.
+        val dfRepartitionWithNum2 = df.repartition(3, 'b)
+        dfRepartitionWithNum2.collect()
+        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
+        // The top shuffle from repartition is not optimized out.
+        assert(hasRepartitionShuffle(planWithNum2))
+        val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
+        assert(smjWithNum2.length == 1)
+        // Skew join can apply as the repartition is not optimized out.
+        assert(smjWithNum2.head.isSkewJoin)
       }
-      assert(repartition.isDefined)
-      checkNumLocalShuffleReaders(planWithNum2, 1)
-      val customReader2 = bhjWithNum2.head.right.find(_.isInstanceOf[CustomShuffleReaderExec])
-      assert(customReader2.isDefined)
-      assert(customReader2.get.asInstanceOf[CustomShuffleReaderExec].isLocalReader)
     }
   }
 }
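For reviewers who want to reproduce the three behaviors by hand, a minimal spark-shell sketch of the cases the test covers. Assumptions: an active SparkSession named `spark`, AQE enabled, and arbitrary demo data — this is an illustration, not part of the patch:

    import org.apache.spark.sql.functions.col

    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.shuffle.partitions", "5")

    val df = spark.range(100).selectExpr("id % 10 AS b")

    // No num: the repartition shuffle can be optimized out. AQE may still
    // coalesce (fewer hash partitions preserve the data distribution), but
    // not skew-split, which would duplicate partitions.
    df.repartition(col("b")).collect()

    // Num equal to the session default (5): the shuffle is optimized out too,
    // but the partition count is now user-specified, so no coalescing either.
    df.repartition(5, col("b")).collect()

    // Non-default num (3): the REPARTITION_WITH_NUM shuffle must be kept to
    // honor the requested count, so skew join can still apply underneath it.
    df.repartition(3, col("b")).collect()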