[SPARK-38237][SQL][SS] Allow ClusteredDistribution to require full clustering keys #35574

Closed · wants to merge 10 commits
@@ -72,9 +72,14 @@ case object AllTuples extends Distribution {
/**
* Represents data where tuples that share the same values for the `clustering`
* [[Expression Expressions]] will be co-located in the same partition.
*
* @param requireAllClusterKeys When true, a `Partitioning` that satisfies this distribution
*                              must match all `clustering` expressions in the same ordering.
*/
case class ClusteredDistribution(
clustering: Seq[Expression],
requireAllClusterKeys: Boolean = SQLConf.get.getConf(
@cloud-fan (Contributor) commented on Feb 25, 2022:
nit: it's less breaking to put the new parameter at the end, so that some caller-side code can remain unchanged.

@c21 (Contributor, Author) commented on Feb 25, 2022:
@cloud-fan - I agree with the point about keeping caller-side code unchanged. It just feels more coherent for readers when `clustering` and `requireAllClusterKeys` sit together; this was also raised in #35574 (comment) by @HeartSaVioR. I am curious whether adding the field in the middle here would actually break external libraries that depend on Spark. Otherwise, reviewers have already paid the cost of reviewing this PR, so I am not sure how important it is to change the caller-side code back. I just want to understand this better, and I am open to changing it back.

Contributor:
OK maybe it's not a big deal
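
For context on the compatibility point above, a REPL-style sketch with hypothetical caller code (only `ClusteredDistribution` and its parameters come from this PR): positional call sites written against the old parameter order stop compiling when the new flag is inserted in the middle, which is why the callers below switch to a named argument.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.types.IntegerType

// Hypothetical clustering keys for illustration.
val keyExprs: Seq[Expression] = Seq(AttributeReference("key", IntegerType)())

// Before this PR the second positional argument was requiredNumPartitions, so a
// caller could write ClusteredDistribution(keyExprs, Some(10)). With
// requireAllClusterKeys inserted as the second parameter, that positional call no
// longer compiles (Some(10) is not a Boolean), so such callers must name the argument:
val dist = ClusteredDistribution(keyExprs, requiredNumPartitions = Some(10))
```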

SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION),
requiredNumPartitions: Option[Int] = None) extends Distribution {
require(
clustering != Nil,
@@ -88,6 +93,19 @@ case class ClusteredDistribution(
s"the actual number of partitions is $numPartitions.")
HashPartitioning(clustering, numPartitions)
}

/**
* Checks if `expressions` match all `clustering` expressions in the same ordering.
*
* `Partitioning` should call this to check its expressions when `requireAllClusterKeys`
* is set to true.
*/
def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = {
expressions.length == clustering.length &&
expressions.zip(clustering).forall {
@sunchao (Member) commented on Feb 24, 2022:
For aggregate or window, I'm not sure whether we have any reordering mechanism similar to join. If not, this could be very limited; for instance, if users have GROUP BY x, y, z while the data distribution is y, z, x, then they have to rewrite the query to match the distribution.

Contributor (Author):
This goes back to the same discussion here: #35574 (comment). I am more inclined to require the same ordering, but if the quorum here thinks we should relax it, then I am also fine. cc @cloud-fan.

Member:
Ah, I forgot this was discussed already (and I participated in that thread too... 😓). I'm fine with the stricter ordering to start with.

case (l, r) => l.semanticEquals(r)
}
}
}
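
To make the ordering concern from the thread above concrete, here is a REPL-style sketch modeled on the DistributionSuite cases added later in this PR (using the catalyst DSL, with `a`, `b`, `c` standing in for column attributes):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}

// Same keys, different order: satisfied under the default (subset) semantics,
// but not once requireAllClusterKeys is set.
HashPartitioning(Seq($"b", $"a", $"c"), 10)
  .satisfies(ClusteredDistribution(Seq($"a", $"b", $"c")))                                // true
HashPartitioning(Seq($"b", $"a", $"c"), 10)
  .satisfies(ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true))  // false

// A strict subset of the keys is likewise rejected when the flag is on.
HashPartitioning(Seq($"b", $"c"), 10)
  .satisfies(ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true))  // false
```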

/**
@@ -261,8 +279,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
case ClusteredDistribution(requiredClustering, _) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
if (requireAllClusterKeys) {
// Checks that `HashPartitioning` is partitioned on exactly the same clustering keys as
// `ClusteredDistribution`.
c.areAllClusterKeysMatched(expressions)
} else {
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
case _ => false
}
}
@@ -322,8 +346,15 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
// `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`.
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering, _) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
val expressions = ordering.map(_.child)
if (requireAllClusterKeys) {
// Checks that `RangePartitioning` is partitioned on exactly the same clustering keys as
// `ClusteredDistribution`.
Contributor:
This is less strict than the previous HashClusteredDistribution, but looks fine.

Contributor (Author):
Yes, it's less strict. Previously we didn't allow RangePartitioning to satisfy HashClusteredDistribution. I think it should be fine too.

c.areAllClusterKeysMatched(expressions)
} else {
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
case _ => false
}
}
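
To make the exchange above concrete, a REPL-style sketch of the new RangePartitioning behavior (mirroring the DistributionSuite cases added below): a range-partitioned child on exactly the clustering keys now satisfies the strict distribution, which the old HashClusteredDistribution never allowed, while a prefix of the keys still does not.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, RangePartitioning}

// Exact clustering keys in the same order: satisfied under the new semantics.
RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10)
  .satisfies(ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true))  // true

// Only a prefix of the keys: not satisfied, so a shuffle would still be added.
RangePartitioning(Seq($"a".asc, $"b".asc), 10)
  .satisfies(ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true))  // false
```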
@@ -524,10 +555,7 @@ case class HashShuffleSpec(
// will add shuffles with the default partitioning of `ClusteredDistribution`, which uses all
// the join keys.
if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
partitioning.expressions.length == distribution.clustering.length &&
partitioning.expressions.zip(distribution.clustering).forall {
case (l, r) => l.semanticEquals(r)
}
distribution.areAllClusterKeysMatched(partitioning.expressions)
} else {
true
}
@@ -407,6 +407,18 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION =
buildConf("spark.sql.requireAllClusterKeysForDistribution")
.internal()
.doc("When true, the planner requires all the clustering keys as the partition keys " +
"(with same ordering) of the children, to eliminate the shuffle for the operator that " +
"requires its children be clustered distributed, such as AGGREGATE and WINDOW node. " +
"This is to avoid data skews which can lead to significant performance regression if " +
"shuffle is eliminated.")
.version("3.3.0")
.booleanConf
.createWithDefault(false)
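
A REPL-style sketch of how the new flag might be exercised from a user session (only the config key comes from this PR; the session setup, column names, and query are illustrative assumptions):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("RequireAllClusterKeysSketch")  // hypothetical app name
  .getOrCreate()
import spark.implicits._

// Opt in to the stricter distribution requirement (default is false).
spark.conf.set("spark.sql.requireAllClusterKeysForDistribution", "true")

val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3)).toDF("key1", "key2", "value")

// Even if the data is already partitioned on key1 alone, the aggregate below now
// requires partitioning on exactly (key1, key2), so a shuffle on both keys is added.
df.repartition($"key1")
  .groupBy("key1", "key2")
  .count()
  .collect()
```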

val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort")
.internal()
.doc("When true, enable use of radix sort when possible. Radix sort is much faster but " +
@@ -169,6 +169,24 @@ class DistributionSuite extends SparkFunSuite {
ClusteredDistribution(Seq($"d", $"e")),
false)

// When ClusteredDistribution.requireAllClusterKeys is set to true,
// HashPartitioning can satisfy ClusteredDistribution only if its hash expressions
// are exactly the same as the required clustering expressions.
checkSatisfied(
HashPartitioning(Seq($"a", $"b", $"c"), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
true)

checkSatisfied(
HashPartitioning(Seq($"b", $"c"), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
false)

checkSatisfied(
HashPartitioning(Seq($"b", $"a", $"c"), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
false)

// HashPartitioning cannot satisfy OrderedDistribution
checkSatisfied(
HashPartitioning(Seq($"a", $"b", $"c"), 10),
@@ -249,22 +267,40 @@
RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
ClusteredDistribution(Seq($"c", $"d")),
false)

// When ClusteredDistribution.requireAllClusterKeys is set to true,
// RangePartitioning can satisfy ClusteredDistribution only if its ordering expressions
// are exactly the same as the required clustering expressions.
checkSatisfied(
RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
true)

checkSatisfied(
RangePartitioning(Seq($"a".asc, $"b".asc), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
false)

checkSatisfied(
RangePartitioning(Seq($"b".asc, $"a".asc, $"c".asc), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
false)
}

test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") {
checkSatisfied(
SinglePartition,
ClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(10)),
false)

checkSatisfied(
HashPartitioning(Seq($"a", $"b", $"c"), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(5)),
false)

checkSatisfied(
RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(5)),
false)
}

@@ -37,7 +37,7 @@ object AQEUtils {
} else {
None
}
Some(ClusteredDistribution(h.expressions, numPartitions))
Some(ClusteredDistribution(h.expressions, requiredNumPartitions = numPartitions))
case f: FilterExec => getRequiredDistribution(f.child)
case s: SortExec if !s.global => getRequiredDistribution(s.child)
case c: CollectMetricsExec => getRequiredDistribution(c.child)
@@ -96,8 +96,10 @@ case class FlatMapGroupsWithStateExec(
// NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
// before making any changes.
// TODO(SPARK-38204)
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(
groupingAttributes, requiredNumPartitions = stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(
initialStateGroupAttrs, requiredNumPartitions = stateInfo.map(_.numPartitions)) ::
Nil
}

@@ -287,6 +287,11 @@ abstract class StreamExecution(
// Disable cost-based join optimization as we do not want stateful operations
// to be rearranged
sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false")
// Disable any config affecting the required child distribution of stateful operators.
// Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution for
// details.
sparkSessionForStream.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key,
"false")

updateStatusMessage("Initializing sources")
// force initialization of the logical plan so that the sources can be created
@@ -340,7 +340,8 @@ case class StateStoreRestoreExec(
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
ClusteredDistribution(keyExpressions,
requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
}
}

@@ -502,7 +503,8 @@ case class StateStoreSaveExec(
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
ClusteredDistribution(keyExpressions,
requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
}
}

@@ -582,7 +584,8 @@ case class SessionWindowStateStoreRestoreExec(
// NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
// before making any changes.
// TODO(SPARK-38204)
ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil
ClusteredDistribution(keyWithoutSessionExpressions,
requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
@@ -696,7 +699,8 @@ case class SessionWindowStateStoreSaveExec(
// NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
// before making any changes.
// TODO(SPARK-38204)
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
ClusteredDistribution(keyExpressions,
requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
}

override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
@@ -757,7 +761,8 @@ case class StreamingDeduplicateExec(
// NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
// before making any changes.
// TODO(SPARK-38204)
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
ClusteredDistribution(keyExpressions,
requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
}

override protected def doExecute(): RDD[InternalRow] = {
@@ -20,9 +20,12 @@ package org.apache.spark.sql
import org.scalatest.matchers.must.Matchers.the

import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -1071,4 +1074,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest
Row("a", 1, "x", "x"),
Row("b", 0, null, null)))
}

test("SPARK-38237: require all cluster keys for child required distribution for window query") {
def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
expressions.flatMap {
case ref: AttributeReference => Some(ref.name)
}
}

def isShuffleExecByRequirement(
plan: ShuffleExchangeExec,
desiredClusterColumns: Seq[String]): Boolean = plan match {
case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS) =>
partitionExpressionsColumns(op.expressions) === desiredClusterColumns
case _ => false
}

val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {

val windowed = df
// Repartition by a subset of the window partitionBy keys, which satisfies
// ClusteredDistribution only when requireAllClusterKeys is disabled.
.repartition($"key1")
.select(
lead($"key1", 1).over(windowSpec),
lead($"value", 1).over(windowSpec))

checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null)))

val shuffleByRequirement = windowed.queryExecution.executedPlan.find {
case w: WindowExec =>
w.child.find {
case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2"))
case _ => false
}.nonEmpty
case _ => false
}

assert(shuffleByRequirement.nonEmpty, "Can't find desired shuffle node from the query plan")
}
}
}
@@ -432,7 +432,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}

test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
val distribution = ClusteredDistribution(Literal(1) :: Nil, requiredNumPartitions = Some(13))
// Number of partitions differ
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)