# [SPARK-49839][SQL] SPJ: Skip shuffles if possible for sorts
### What changes were proposed in this pull request?

This is a proposal to skip the shuffle for ORDER BY and other sort operations when the sort keys are the table's partition columns.

### Why are the changes needed?

This could optimize many jobs: today all data is shuffled for such sorts, even though the partition values are already known and can simply be sorted.
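For instance, a query like the following sketch (the `testcat` catalog and `items` table here are illustrative, mirroring the new test below) currently incurs a range shuffle even though each key-grouped input partition holds exactly one combination of the sort keys:

```scala
// A V2 table partitioned by identity transforms on the sort columns.
spark.sql(
  "CREATE TABLE testcat.ns.items (id BIGINT, name STRING, price FLOAT) " +
    "PARTITIONED BY (price, id)")

// The sort keys equal the partition keys, so with this change the scan's
// key-grouped partitions can be emitted in sorted order with no exchange.
spark.sql("SELECT price, id FROM testcat.ns.items ORDER BY price, id").collect()
```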

### Does this PR introduce _any_ user-facing change?

 No

### How was this patch tested?

Added a test in `KeyGroupedPartitioningSuite`.

### Was this patch authored or co-authored using generative AI tooling?

  No
szehon-ho committed Oct 1, 2024 · 1 parent 80d6651 · commit 0365494
Showing 4 changed files with 97 additions and 1 deletion.
`partitioning.scala`:

```diff
@@ -176,6 +176,13 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   override def createPartitioning(numPartitions: Int): Partitioning = {
     RangePartitioning(ordering, numPartitions)
   }
+
+  def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = {
+    expressions.length == ordering.length &&
+      expressions.zip(ordering).forall {
+        case (x, o) => x.semanticEquals(o.child)
+      }
+  }
 }

 /**
@@ -394,6 +401,9 @@ case class KeyGroupedPartitioning(
         }
       }

+      case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
+        o.areAllClusterKeysMatched(expressions)
+
       case _ =>
         false
     }
```
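The new check is strict: the sort must cover every partition key, in partition order. A self-contained sketch of how the rule behaves (`Expr` and `SortKey` are hypothetical stand-ins for Catalyst's `Expression` and `SortOrder`; the real code compares with `semanticEquals` on expression trees):

```scala
object MatchSketch extends App {
  case class Expr(name: String) {
    def semanticEquals(other: Expr): Boolean = name == other.name
  }
  case class SortKey(child: Expr) // stands in for SortOrder

  def areAllClusterKeysMatched(expressions: Seq[Expr], ordering: Seq[SortKey]): Boolean =
    expressions.length == ordering.length &&
      expressions.zip(ordering).forall { case (x, o) => x.semanticEquals(o.child) }

  val partitionKeys = Seq(Expr("price"), Expr("id"))
  // ORDER BY price, id lines up one-to-one with partitioning (price, id): matched.
  assert(areAllClusterKeysMatched(
    partitionKeys, Seq(SortKey(Expr("price")), SortKey(Expr("id")))))
  // ORDER BY id alone does not cover both partition keys: not matched.
  assert(!areAllClusterKeysMatched(partitionKeys, Seq(SortKey(Expr("id")))))
}
```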
`SQLConf.scala` (the original `.doc` string was missing spaces at two concatenation points; fixed here):

```diff
@@ -1653,6 +1653,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  val V2_BUCKETING_SORTING_ENABLED =
+    buildConf("spark.sql.sources.v2.bucketing.sorting.enabled")
+      .doc("When turned on, Spark will recognize the specific distribution reported by " +
+        "a V2 data source through SupportsReportPartitioning, and will try to avoid a " +
+        "shuffle if possible when sorting by those columns. This config requires " +
+        s"${V2_BUCKETING_ENABLED.key} to be enabled.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -5756,6 +5766,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingAllowCompatibleTransforms: Boolean =
     getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)

+  def v2BucketingAllowSorting: Boolean =
+    getConf(SQLConf.V2_BUCKETING_SORTING_ENABLED)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
```
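A sketch of enabling the optimization for a session; the prerequisite flag's key is `spark.sql.sources.v2.bucketing.enabled`, and the session wiring below is illustrative:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("spj-sort-demo")
  // Both flags default to false; the sorting flag only takes effect
  // together with V2 bucketing.
  .config("spark.sql.sources.v2.bucketing.enabled", "true")
  .config("spark.sql.sources.v2.bucketing.sorting.enabled", "true")
  .getOrCreate()
```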
`EnsureRequirements.scala` (an explicit return type is added to `ensureOrdering`):

```diff
@@ -64,7 +64,7 @@ case class EnsureRequirements(
     // Ensure that the operator's children satisfy their output distribution requirements.
     var children = originalChildren.zip(requiredChildDistributions).map {
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
+        ensureOrdering(child, distribution)
       case (child, BroadcastDistribution(mode)) =>
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
@@ -281,6 +281,23 @@ case class EnsureRequirements(
     }
   }

+  private def ensureOrdering(plan: SparkPlan, distribution: Distribution): SparkPlan = {
+    (plan.outputPartitioning, distribution) match {
+      case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
+          d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
+        val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
+        val partitionOrdering: Ordering[InternalRow] = {
+          RowOrdering.create(ordering, attrs)
+        }
+        // Sort 'commonPartitionValues' and use this mechanism to ensure BatchScan's
+        // output partitions are ordered
+        val sorted = partitionValues.sorted(partitionOrdering)
+        populateCommonPartitionInfo(plan, sorted.map((_, 1)),
+          None, None, applyPartialClustering = false, replicatePartitions = false)
+      case _ => plan
+    }
+  }

   /**
    * Recursively reorders the join keys based on partitioning. It starts reordering the
    * join keys to match HashPartitioning on either side, followed by PartitioningCollection.
```
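Note that no sort node is inserted at all: because each key-grouped partition holds exactly one partition value, sorting the reported partition values fixes the order of the scan's output partitions, and concatenating the partitions in that order is already a global sort. A minimal standalone model of the idea (plain tuples stand in for `InternalRow`, and a tuple `Ordering` for `RowOrdering.create`):

```scala
object PartitionOrderSketch extends App {
  // Each entry is the key of one key-grouped partition: (price, id).
  val partitionValues = Seq((15.5, 3), (41.0, 1), (10.0, 2))

  // ORDER BY price ASC, id ASC over the partition keys.
  val byPriceThenId: Ordering[(Double, Int)] =
    Ordering.Tuple2(Ordering.Double.TotalOrdering, Ordering.Int)

  // Sorting the partition *values* (not the rows) is enough: emitting the
  // partitions in this order yields globally sorted output with no shuffle.
  val sorted = partitionValues.sorted(byPriceThenId)
  assert(sorted == Seq((10.0, 2), (15.5, 3), (41.0, 1)))
  println(sorted)
}
```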
`KeyGroupedPartitioningSuite.scala` (a stray `: Unit` ascription, a trailing semicolon, and needless `s` interpolators are cleaned up):

```diff
@@ -370,6 +370,62 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
   }

+  test("SPARK-48655: order by on partition keys should not introduce additional shuffle") {
+    val items_partitions = Array(identity("price"), identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+      "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+      "(null, 'cc', 15.5, cast('2020-02-01' as timestamp)), " +
+      "(3, 'cc', null, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach { sortingEnabled =>
+      withSQLConf(SQLConf.V2_BUCKETING_SORTING_ENABLED.key -> sortingEnabled.toString) {
+
+        def verifyShuffle(cmd: String, answer: Seq[Row]): Unit = {
+          val df = sql(cmd)
+          if (sortingEnabled) {
+            assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty,
+              "should contain no shuffle when sorting by partition values")
+          } else {
+            assert(collectAllShuffles(df.queryExecution.executedPlan).size == 1,
+              "should contain one shuffle when optimization is disabled")
+          }
+          checkAnswer(df, answer)
+        }
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items ORDER BY price ASC, id ASC",
+          Seq(Row(null, 3), Row(10.0, 2), Row(15.5, null),
+            Row(15.5, 3), Row(40.0, 1), Row(41.0, 1)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items " +
+            "ORDER BY price ASC NULLS LAST, id ASC NULLS LAST",
+          Seq(Row(10.0, 2), Row(15.5, 3), Row(15.5, null),
+            Row(40.0, 1), Row(41.0, 1), Row(null, 3)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id ASC",
+          Seq(Row(41.0, 1), Row(40.0, 1), Row(15.5, null),
+            Row(15.5, 3), Row(10.0, 2), Row(null, 3)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id DESC",
+          Seq(Row(41.0, 1), Row(40.0, 1), Row(15.5, 3),
+            Row(15.5, null), Row(10.0, 2), Row(null, 3)))
+
+        verifyShuffle(
+          s"SELECT price, id FROM testcat.ns.$items " +
+            "ORDER BY price DESC NULLS FIRST, id DESC NULLS FIRST",
+          Seq(Row(null, 3), Row(41.0, 1), Row(40.0, 1),
+            Row(15.5, null), Row(15.5, 3), Row(10.0, 2)))
+      }
+    }
+  }
+
   test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") {
     val cols = Array(
       Column.create("id", LongType),
```
