
Commit d8920ae

sunchao authored and cloud-fan committed
[SPARK-35703][SQL] Relax constraint for bucket join and remove HashClusteredDistribution
### What changes were proposed in this pull request?

This PR proposes the following:

1. Introduce a new trait `ShuffleSpec`, used by `EnsureRequirements` when a node has more than one child. It serves two purposes: 1) compare all children and check whether they are compatible w.r.t. partitioning and distribution, and 2) create a new partitioning to re-shuffle the other side in case they are not compatible.
2. Remove `HashClusteredDistribution` and replace its usages with `ClusteredDistribution`.

Under the new mechanism, when `EnsureRequirements` checks whether shuffles are required for a plan node with more than one child, it does the following:

1. Check each child of the node and see whether it can satisfy the corresponding required distribution.
2. Check whether all children of the node are compatible with each other w.r.t. their partitioning and distribution.
3. If step 2 fails, choose the best shuffle spec (in terms of shuffle parallelism) that can be used to repartition the other children, so that they all have compatible partitioning.

### Why are the changes needed?

Spark currently only allows bucket join when the set of cluster keys from the output partitioning _exactly matches_ the set of join keys from the required distribution. For instance, in the following:

```sql
SELECT * FROM A JOIN B ON A.c1 = B.c1 AND A.c2 = B.c2
```

bucket join will only be triggered if both `A` and `B` are bucketed on columns `c1` and `c2`, in which case Spark will avoid shuffling both sides of the join.

The above requirement, however, is too strict: shuffle can also be avoided if both `A` and `B` are bucketed on either column `c1` or `c2`. That is, if all rows that have the same value in column `c1` are clustered into the same partition, then all rows that have the same values in both columns `c1` and `c2` are also clustered into the same partition.

To allow this, we need to change the logic for deciding whether the two sides of a join operator are "co-partitioned". Currently, this is done by checking each side's output partitioning against its required distribution separately, using the `Partitioning.satisfies` method. Since `HashClusteredDistribution` requires a `HashPartitioning` to exactly match the cluster keys, the check can be done in isolation, without looking at the other side's output partitioning and required distribution. That approach is no longer valid once we relax the above constraint, as we need to compare the output partitioning and required distribution **on both sides**. For instance, in the above example, if `A` is bucketed on `c1` while `B` is bucketed on `c2`, we may need to do the following:

1. Identify where `A.c1` and `B.c2` are used in the join keys (e.g., positions 0 and 1, respectively).
2. Check whether the positions derived from both sides exactly match each other (this becomes more complicated if a key appears in multiple positions within the join keys).

To achieve the above, this PR proposes the following:

```scala
trait ShuffleSpec {
  // Used as a cost indicator to shuffle children
  def numPartitions: Int

  // Used to check whether this spec is compatible with `other`
  def isCompatibleWith(other: ShuffleSpec): Boolean

  // Used to create a new partitioning for the other `distribution` in case `isCompatibleWith` failed.
  def createPartitioning(distribution: Distribution): Partitioning
}
```

A similar API is also required if we are going to support DSv2 `DataSourcePartitioning` as output partitioning in the bucket join scenario, or to support custom hash functions such as `HiveHash` for bucketing. With the former, even if both `A` and `B` are partitioned on columns `c1` and `c2` in the above example, they could be partitioned via different transform expressions, e.g., `A` on `(bucket(32, c1), day(c2))` while `B` on `(bucket(32, c1), hour(c2))`. This means we need to compare the partitionings from both sides of the join, which makes the current approach with `Partitioning.satisfies` insufficient. The same `isCompatibleWith` API can potentially be reused for this purpose.

### Does this PR introduce _any_ user-facing change?

Yes, bucket join is now enabled in more cases, as described above.

### How was this patch tested?

1. Added a new test suite `ShuffleSpecSuite`.
2. Added additional tests in `EnsureRequirementsSuite`.

Closes #32875 from sunchao/SPARK-35703.

Authored-by: Chao Sun <sunchao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
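To make the position-matching idea above concrete, here is a minimal, self-contained sketch of the check (illustrative only: `keyPositions` and `compatible` are stand-in names, not Spark APIs, and the real logic added in `HashShuffleSpec` below operates on canonicalized `Expression`s and `BitSet`s):

```scala
object CoPartitionSketch {
  // For each hash (bucket) key, collect the positions where it appears in the join keys.
  def keyPositions(joinKeys: Seq[String], hashKeys: Seq[String]): Seq[Set[Int]] =
    hashKeys.map(k => joinKeys.zipWithIndex.collect { case (jk, i) if jk == k => i }.toSet)

  // Two sides can be treated as co-partitioned if they have the same number of hash keys
  // (and partitions, elided here) and each pair of keys overlaps in at least one position.
  def compatible(left: Seq[Set[Int]], right: Seq[Set[Int]]): Boolean =
    left.length == right.length &&
      left.zip(right).forall { case (l, r) => l.intersect(r).nonEmpty }

  def main(args: Array[String]): Unit = {
    // A JOIN B ON A.c1 = B.c1 AND A.c2 = B.c2; join keys are (c1, c2) on each side.
    // Case 1: both sides bucketed on c1 -> positions overlap at 0, shuffle can be avoided.
    println(compatible(
      keyPositions(Seq("c1", "c2"), Seq("c1")),
      keyPositions(Seq("c1", "c2"), Seq("c1"))))  // true
    // Case 2: A bucketed on c1, B bucketed on c2 -> positions {0} vs {1}, must shuffle.
    println(compatible(
      keyPositions(Seq("c1", "c2"), Seq("c1")),
      keyPositions(Seq("c1", "c2"), Seq("c2"))))  // false
  }
}
```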
1 parent: 75f919e

File tree

39 files changed: +4372 −3533 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 175 additions & 29 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -87,31 +89,6 @@ case class ClusteredDistribution(
   }
 }
 
-/**
- * Represents data where tuples have been clustered according to the hash of the given
- * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
- * [[HashPartitioning]] can satisfy this distribution.
- *
- * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
- * number of partitions, this distribution strictly requires which partition the tuple should be in.
- */
-case class HashClusteredDistribution(
-    expressions: Seq[Expression],
-    requiredNumPartitions: Option[Int] = None) extends Distribution {
-  require(
-    expressions != Nil,
-    "The expressions for hash of a HashClusteredDistribution should not be Nil. " +
-      "An AllTuples should be used to represent a distribution that only has " +
-      "a single partition.")
-
-  override def createPartitioning(numPartitions: Int): Partitioning = {
-    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
-      s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
-        s"the actual number of partitions is $numPartitions.")
-    HashPartitioning(expressions, numPartitions)
-  }
-}
-
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]]. Its requirement is defined as the following:
@@ -171,6 +148,17 @@ trait Partitioning {
     required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
   }
 
+  /**
+   * Creates a shuffle spec for this partitioning and its required distribution. The
+   * spec is used in the scenario where an operator has multiple children (e.g., join), and is
+   * used to decide whether this child is co-partitioned with others, therefore whether extra
+   * shuffle shall be introduced.
+   *
+   * @param distribution the required clustered distribution for this partitioning
+   */
+  def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    throw new IllegalStateException(s"Unexpected partitioning: ${getClass.getSimpleName}")
+
   /**
    * The actual method that defines whether this [[Partitioning]] can satisfy the given
    * [[Distribution]], after the `numPartitions` check.
@@ -202,6 +190,9 @@ case object SinglePartition extends Partitioning {
     case _: BroadcastDistribution => false
     case _ => true
   }
+
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    SinglePartitionShuffleSpec
 }
 
 /**
@@ -219,17 +210,16 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
       required match {
-        case h: HashClusteredDistribution =>
-          expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
-            case (l, r) => l.semanticEquals(r)
-          }
         case ClusteredDistribution(requiredClustering, _) =>
           expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
         case _ => false
       }
     }
   }
 
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    HashShuffleSpec(this, distribution)
+
   /**
    * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less
    * than numPartitions) based on hashing expressions.
@@ -288,6 +278,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     }
   }
 
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    RangeShuffleSpec(this.numPartitions, distribution)
+
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): RangePartitioning =
     copy(ordering = newChildren.asInstanceOf[Seq[SortOrder]])
@@ -330,6 +323,11 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def satisfies0(required: Distribution): Boolean =
     partitionings.exists(_.satisfies(required))
 
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
+    val filtered = partitionings.filter(_.satisfies(distribution))
+    ShuffleSpecCollection(filtered.map(_.createShuffleSpec(distribution)))
+  }
+
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
```
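The hunks above hook `createShuffleSpec` into each concrete `Partitioning`; the new `ShuffleSpec` hierarchy itself is added below. As a hedged sketch of the intended call pattern (the actual `EnsureRequirements` changes live in a different file and are not shown in this section; `needShuffle` is an illustrative name, not a Spark API):

```scala
import org.apache.spark.sql.catalyst.plans.physical._

// For a two-child operator (e.g., a join), each child's output partitioning plus its
// required clustered distribution yields a spec; compatible specs mean the children are
// co-partitioned, so no extra shuffle is needed between them.
def needShuffle(
    left: (Partitioning, ClusteredDistribution),
    right: (Partitioning, ClusteredDistribution)): Boolean = {
  val leftSpec = left._1.createShuffleSpec(left._2)
  val rightSpec = right._1.createShuffleSpec(right._2)
  !leftSpec.isCompatibleWith(rightSpec)
}
```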
```diff
@@ -352,3 +350,151 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
     case _ => false
   }
 }
+
+/**
+ * This is used in the scenario where an operator has multiple children (e.g., join) and one or more
+ * of which have their own requirement regarding whether its data can be considered as
+ * co-partitioned from others. This offers APIs for:
+ *
+ *   - Comparing with specs from other children of the operator and check if they are compatible.
+ *     When two specs are compatible, we can say their data are co-partitioned, and Spark will
+ *     potentially be able to eliminate shuffle if necessary.
+ *   - Creating a partitioning that can be used to re-partition another child, so that to make it
+ *     having a compatible partitioning as this node.
+ */
+trait ShuffleSpec {
+  /**
+   * Returns the number of partitions of this shuffle spec
+   */
+  def numPartitions: Int
+
+  /**
+   * Returns true iff this spec is compatible with the provided shuffle spec.
+   *
+   * A true return value means that the data partitioning from this spec can be seen as
+   * co-partitioned with the `other`, and therefore no shuffle is required when joining the two
+   * sides.
+   */
+  def isCompatibleWith(other: ShuffleSpec): Boolean
+
+  /**
+   * Whether this shuffle spec can be used to create partitionings for the other children.
+   */
+  def canCreatePartitioning: Boolean = false
+
+  /**
+   * Creates a partitioning that can be used to re-partition the other side with the given
+   * clustering expressions.
+   *
+   * This will only be called when:
+   *  - [[canCreatePartitioning]] returns true.
+   *  - [[isCompatibleWith]] returns false on the side where the `clustering` is from.
+   */
+  def createPartitioning(clustering: Seq[Expression]): Partitioning =
+    throw new UnsupportedOperationException("Operation unsupported for " +
+      s"${getClass.getCanonicalName}")
+}
+
+case object SinglePartitionShuffleSpec extends ShuffleSpec {
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = {
+    other.numPartitions == 1
+  }
+
+  override def canCreatePartitioning: Boolean = true
+
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning =
+    SinglePartition
+
+  override def numPartitions: Int = 1
+}
+
+case class RangeShuffleSpec(
+    numPartitions: Int,
+    distribution: ClusteredDistribution) extends ShuffleSpec {
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec => numPartitions == 1
+    case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith)
+    // `RangePartitioning` is not compatible with any other partitioning since it can't guarantee
+    // data are co-partitioned for all the children, as range boundaries are randomly sampled.
+    case _ => false
+  }
+}
+
+case class HashShuffleSpec(
+    partitioning: HashPartitioning,
+    distribution: ClusteredDistribution) extends ShuffleSpec {
+  lazy val hashKeyPositions: Seq[mutable.BitSet] =
+    createHashKeyPositions(distribution.clustering, partitioning.expressions)
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      partitioning.numPartitions == 1
+    case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherDistribution) =>
+      // we need to check:
+      //  1. both distributions have the same number of clustering expressions
+      //  2. both partitioning have the same number of partitions
+      //  3. both partitioning have the same number of expressions
+      //  4. each pair of expression from both has overlapping positions in their
+      //     corresponding distributions.
+      distribution.clustering.length == otherDistribution.clustering.length &&
+        partitioning.numPartitions == otherPartitioning.numPartitions &&
+        partitioning.expressions.length == otherPartitioning.expressions.length && {
+          val otherHashKeyPositions = otherHashSpec.hashKeyPositions
+          hashKeyPositions.zip(otherHashKeyPositions).forall { case (left, right) =>
+            left.intersect(right).nonEmpty
+          }
+        }
+    case ShuffleSpecCollection(specs) =>
+      specs.exists(isCompatibleWith)
+    case _ =>
+      false
+  }
+
+  override def canCreatePartitioning: Boolean = true
+
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
+    val exprs = hashKeyPositions.map(v => clustering(v.head))
+    HashPartitioning(exprs, partitioning.numPartitions)
+  }
+
+  override def numPartitions: Int = partitioning.numPartitions
+
+  /**
+   * Returns a sequence where each element is a set of positions of the key in `hashKeys` to its
+   * positions in `requiredClusterKeys`. For instance, if `requiredClusterKeys` is [a, b, b] and
+   * `hashKeys` is [a, b], the result will be [(0), (1, 2)].
+   */
+  private def createHashKeyPositions(
+      requiredClusterKeys: Seq[Expression],
+      hashKeys: Seq[Expression]): Seq[mutable.BitSet] = {
+    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+    requiredClusterKeys.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
+    }
+
+    hashKeys.map(k => distKeyToPos(k.canonicalized))
+  }
+}
+
+case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = {
+    specs.exists(_.isCompatibleWith(other))
+  }
+
+  override def canCreatePartitioning: Boolean =
+    specs.forall(_.canCreatePartitioning)
+
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
+    // as we only consider # of partitions as the cost now, it doesn't matter which one we choose
+    // since they should all have the same # of partitions.
+    require(specs.map(_.numPartitions).toSet.size == 1, "expected all specs in the collection " +
+      "to have the same number of partitions")
+    specs.head.createPartitioning(clustering)
+  }
+
+  override def numPartitions: Int = {
+    require(specs.nonEmpty, "expected specs to be non-empty")
+    specs.head.numPartitions
+  }
+}
```
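As a quick illustration of `createHashKeyPositions`, the hedged sketch below reproduces the `[a, b, b]` / `[a, b]` example from its doc comment, with plain strings standing in for canonicalized expressions so it can run in a Scala REPL without Spark on the classpath:

```scala
import scala.collection.mutable

// Mirrors the private helper above, minus Expression canonicalization.
def createHashKeyPositions(
    requiredClusterKeys: Seq[String],
    hashKeys: Seq[String]): Seq[mutable.BitSet] = {
  // Map each distinct cluster key to the set of positions where it appears.
  val distKeyToPos = mutable.Map.empty[String, mutable.BitSet]
  requiredClusterKeys.zipWithIndex.foreach { case (distKey, distKeyPos) =>
    distKeyToPos.getOrElseUpdate(distKey, mutable.BitSet.empty).add(distKeyPos)
  }
  hashKeys.map(distKeyToPos)
}

println(createHashKeyPositions(Seq("a", "b", "b"), Seq("a", "b")))
// List(BitSet(0), BitSet(1, 2))
```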

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala

Lines changed: 0 additions & 38 deletions
```diff
@@ -133,11 +133,6 @@ class DistributionSuite extends SparkFunSuite {
       ClusteredDistribution(Seq($"a", $"b", $"c")),
       true)
 
-    checkSatisfied(
-      SinglePartition,
-      HashClusteredDistribution(Seq($"a", $"b", $"c")),
-      true)
-
     checkSatisfied(
       SinglePartition,
       OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
@@ -172,23 +167,6 @@ class DistributionSuite extends SparkFunSuite {
       ClusteredDistribution(Seq($"d", $"e")),
       false)
 
-    // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly
-    // same with the required hash clustering expressions.
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      HashClusteredDistribution(Seq($"a", $"b", $"c")),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"c", $"b", $"a"), 10),
-      HashClusteredDistribution(Seq($"a", $"b", $"c")),
-      false)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b"), 10),
-      HashClusteredDistribution(Seq($"a", $"b", $"c")),
-      false)
-
     // HashPartitioning cannot satisfy OrderedDistribution
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b", $"c"), 10),
@@ -269,12 +247,6 @@ class DistributionSuite extends SparkFunSuite {
       RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
       ClusteredDistribution(Seq($"c", $"d")),
       false)
-
-    // RangePartitioning cannot satisfy HashClusteredDistribution
-    checkSatisfied(
-      RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
-      HashClusteredDistribution(Seq($"a", $"b", $"c")),
-      false)
   }
 
   test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") {
@@ -283,21 +255,11 @@ class DistributionSuite extends SparkFunSuite {
       ClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
       false)
 
-    checkSatisfied(
-      SinglePartition,
-      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
-      false)
-
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b", $"c"), 10),
       ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
      false)
 
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      HashClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
-      false)
-
     checkSatisfied(
       RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
       ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
```
