Commit c81c039

[SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when they are not equal
### What changes were proposed in this pull request?
-- Allow SPJ between 'compatible' bucket functions
-- Add a mechanism to define 'reducible' functions, one function whose output can be 'reduced' to another for all inputs.

### Why are the changes needed?
-- SPJ currently applies only if the partition transform expressions on both sides are identical.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new tests in KeyGroupedPartitioningSuite

### Was this patch authored or co-authored using generative AI tooling?
No
1 parent 76c4fd5 commit c81c039

File tree: 9 files changed (+531 -22 lines)
Reducer.java (new file)

Lines changed: 16 additions & 0 deletions

package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
 * A 'reducer' for output of user-defined functions.
 *
 * A user-defined function f_source(x) is 'reducible' on another user-defined function f_target(x),
 * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x.
 * @param <T> function output type
 * @since 4.0.0
 */
@Evolving
public interface Reducer<T> {
  T reduce(T arg1);
}
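To make the contract concrete, here is a minimal Scala sketch (not part of this commit) of a Reducer: a hypothetical BucketReducer that maps a bucket id computed with a larger bucket count down to the id a smaller bucket count would have produced, assuming bucket(n, x) = hash(x) % n and that the smaller count divides the larger one.

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Hypothetical reducer: r(bucket(large, x)) == bucket(small, x) whenever small divides large,
// because (hash(x) % large) % small == hash(x) % small in that case.
case class BucketReducer(largeNumBuckets: Int, smallNumBuckets: Int) extends Reducer[Int] {
  override def reduce(bucketId: Int): Int = bucketId % smallNumBuckets
}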
ReducibleFunction.java (new file)

Lines changed: 25 additions & 0 deletions

package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import scala.Option;

/**
 * Base class for user-defined functions that can be 'reduced' on another function.
 *
 * A function f_source(x) is 'reducible' on another function f_target(x) if
 * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
 *
 * @since 4.0.0
 */
@Evolving
public interface ReducibleFunction<T, A> extends ScalarFunction<T> {

  /**
   * If this function is 'reducible' on another function, return the {@link Reducer} function.
   * @param other other function
   * @param thisArgument argument for this function instance
   * @param otherArgument argument for other function instance
   * @return a reduction function if it is reducible, none if not
   */
  Option<Reducer<A>> reducer(ReducibleFunction<?, ?> other, Option<?> thisArgument, Option<?> otherArgument);
}
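As an illustration of how a data source might implement this interface (a sketch under assumptions, not code from this commit; the commit's actual coverage is in KeyGroupedPartitioningSuite), the Scala object below is a hash-bucket ScalarFunction that declares itself reducible onto another instance of the same function whenever the other side's bucket count divides its own, returning the hypothetical BucketReducer from the previous sketch.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Hypothetical bucket function: bucket(numBuckets, value) = value % numBuckets.
object BucketFunction extends ReducibleFunction[Int, Int] {
  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "bucket"
  override def produceResult(input: InternalRow): Int =
    (input.getLong(1) % input.getInt(0)).toInt

  // Reducible onto another bucket function whose bucket count divides this one's.
  override def reducer(
      other: ReducibleFunction[_, _],
      thisNumBuckets: Option[_],
      otherNumBuckets: Option[_]): Option[Reducer[Int]] = {
    (other, thisNumBuckets, otherNumBuckets) match {
      case (BucketFunction, Some(thisBuckets: Int), Some(otherBuckets: Int))
          if thisBuckets % otherBuckets == 0 =>
        Some(BucketReducer(thisBuckets, otherBuckets))
      case _ => None
    }
  }
}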

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala

Lines changed: 26 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.connector.catalog.functions.BoundFunction
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ReducibleFunction}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -54,6 +54,31 @@ case class TransformExpression(
     false
   }
 
+  /**
+   * Whether this [[TransformExpression]]'s function is compatible with the `other`
+   * [[TransformExpression]]'s function.
+   *
+   * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]] r(x)
+   * such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x.
+   *
+   * @param other the transform expression to compare to
+   * @return true if compatible, false if not
+   */
+  def isCompatible(other: TransformExpression): Boolean = {
+    if (isSameFunction(other)) {
+      true
+    } else {
+      (function, other.function) match {
+        case (f: ReducibleFunction[Any, Any] @unchecked,
+            o: ReducibleFunction[Any, Any] @unchecked) =>
+          val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt)
+          val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt)
+          reducer.isDefined || otherReducer.isDefined
+        case _ => false
+      }
+    }
+  }
+
   override def dataType: DataType = function.resultType()
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
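For instance, assuming the hypothetical BucketFunction above and the existing TransformExpression(function, children, numBucketsOpt) constructor, two bucket transforms with different bucket counts would now be considered compatible whenever one count divides the other (a sketch, not from this commit):

// keyAttr stands for the join key column's attribute (assumed to be in scope).
val left  = TransformExpression(BucketFunction, Seq(keyAttr), Some(32))
val right = TransformExpression(BucketFunction, Seq(keyAttr), Some(8))

left.isSameFunction(right)  // false: the bucket counts differ
left.isCompatible(right)    // true: BucketFunction can reduce 32 buckets onto 8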

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 64 additions & 7 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -635,6 +636,22 @@ trait ShuffleSpec {
   */
  def createPartitioning(clustering: Seq[Expression]): Partitioning =
    throw SparkUnsupportedOperationException()
+
+  /**
+   * Return a set of [[Reducer]] for the partition expressions of this shuffle spec,
+   * on the partition expressions of another shuffle spec.
+   * <p>
+   * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
+   * 'reducible' on the corresponding partition expression function of the other shuffle spec.
+   * <p>
+   * If a value is returned, there must be one Option[[Reducer]] per partition expression.
+   * A None value in the set indicates that the particular partition expression is not reducible
+   * on the corresponding expression on the other shuffle spec.
+   * <p>
+   * Returning none also indicates that none of the partition expressions can be reduced on the
+   * corresponding expression on the other shuffle spec.
+   */
+  def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None
 }
 
 case object SinglePartitionShuffleSpec extends ShuffleSpec {
@@ -829,20 +846,60 @@ case class KeyGroupedShuffleSpec(
     }
   }
 
+  override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
+    // Only support partition expressions are AttributeReference for now
+    partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
+
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
+    KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
+  }
+
+  override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = {
+    other match {
+      case otherSpec: KeyGroupedShuffleSpec =>
+        val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map {
+          case (e1: TransformExpression, e2: TransformExpression)
+            if e1.function.isInstanceOf[ReducibleFunction[Any, Any] @unchecked]
+              && e2.function.isInstanceOf[ReducibleFunction[Any, Any] @unchecked] =>
+            e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer(
+              e2.function.asInstanceOf[ReducibleFunction[Any, Any]],
+              e1.numBucketsOpt.map(a => a.asInstanceOf[Any]),
+              e2.numBucketsOpt.map(a => a.asInstanceOf[Any]))
+          case (_, _) => None
+        }
+
+        // optimize to not return a value, if none of the partition expressions need reducing
+        if (results.forall(p => p.isEmpty)) None else Some(results)
+      case _ => None
+    }
+  }
+
   private def isExpressionCompatible(left: Expression, right: Expression): Boolean =
     (left, right) match {
       case (_: LeafExpression, _: LeafExpression) => true
       case (left: TransformExpression, right: TransformExpression) =>
-        left.isSameFunction(right)
+        if (SQLConf.get.v2BucketingPushPartValuesEnabled &&
+          !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+          SQLConf.get.v2BucketingAllowCompatibleTransforms) {
+          left.isCompatible(right)
+        } else {
+          left.isSameFunction(right)
+        }
      case _ => false
    }
+}
 
-  override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
-    // Only support partition expressions are AttributeReference for now
-    partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
-
-  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
-    KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
+object KeyGroupedShuffleSpec {
+  def reducePartitionValue(row: InternalRow,
+      expressions: Seq[Expression],
+      reducers: Seq[Option[Reducer[Any]]]): InternalRowComparableWrapper = {
+    val partitionVals = row.toSeq(expressions.map(_.dataType))
+    val reducedRow = partitionVals.zip(reducers).map {
+      case (v, Some(reducer)) => reducer.reduce(v)
+      case (v, _) => v
+    }.toArray
+    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
   }
 }
 
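The effect of reducePartitionValue is easiest to see on plain values. In the hypothetical 16-bucket vs. 8-bucket scenario from the earlier sketches, reducing the 16-bucket side's partition values makes previously distinct bucket ids collapse onto the same key, so they group with the matching 8-bucket partitions (illustration only, using plain Ints in place of single-column InternalRows):

val reducer = BucketReducer(largeNumBuckets = 16, smallNumBuckets = 8)

// Partition values of the 16-bucket side, as plain bucket ids.
val leftPartitionValues = Seq(0, 3, 8, 11)

// After reduction, ids 0 and 8 collapse to 0, and 3 and 11 collapse to 3,
// matching the partition values the 8-bucket side already has.
val grouped = leftPartitionValues.groupBy(reducer.reduce)
// Map(0 -> Seq(0, 8), 3 -> Seq(3, 11))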

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 0 deletions
@@ -1537,6 +1537,18 @@
       .booleanConf
       .createWithDefault(false)
 
+  val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS =
+    buildConf("spark.sql.sources.v2.bucketing.allow.enabled")
+      .doc("Whether to allow storage-partition join in the case where the partition transforms " +
+        "are compatible but not identical. This config requires both " +
+        s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
+        s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+        "to be disabled."
+      )
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -5201,6 +5213,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
     getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
 
+  def v2BucketingAllowCompatibleTransforms: Boolean =
+    getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 17 additions & 4 deletions
@@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.util.ArrayImplicits._
 
@@ -164,6 +165,18 @@
             (groupedParts, expressions)
         }
 
+        // Also re-group the partitions if we are reducing compatible partition expressions
+        val finalGroupedPartitions = spjParams.reducers match {
+          case Some(reducers) =>
+            val result = groupedPartitions.groupBy { case (row, _) =>
+              KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+            }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
+            val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
+              expressions.map(_.dataType))
+            result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+          case _ => groupedPartitions
+        }
+
         // When partially clustered, the input partitions are not grouped by partition
         // values. Here we'll need to check `commonPartitionValues` and decide how to group
         // and replicate splits within a partition.
@@ -174,7 +187,7 @@
             .get
             .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
             .toMap
-          val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
+          val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
             // `commonPartValuesMap` should contain the part value since it's the super set.
             val numSplits = commonPartValuesMap
               .get(InternalRowComparableWrapper(partValue, partExpressions))
@@ -207,7 +220,7 @@
         } else {
           // either `commonPartitionValues` is not defined, or it is defined but
           // `applyPartialClustering` is false.
-          val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
+          val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
            InternalRowComparableWrapper(partValue, partExpressions) -> splits
          }.toMap
 
@@ -224,7 +237,6 @@
 
       case _ => filteredPartitions
     }
-
     new DataSourceRDD(
       sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics)
   }
@@ -259,6 +271,7 @@ case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
     joinKeyPositions: Option[Seq[Int]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    reducers: Option[Seq[Option[Reducer[Any]]]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
   override def equals(other: Any): Boolean = other match {

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 40 additions & 9 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
@@ -505,11 +506,28 @@
          }
        }
 
-        // Now we need to push-down the common partition key to the scan in each child
-        newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions,
-          applyPartialClustering, replicateLeftSide)
-        newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions,
-          applyPartialClustering, replicateRightSide)
+        // in case of compatible but not identical partition expressions, we apply 'reduce'
+        // transforms to group one side's partitions as well as the common partition values
+        val leftReducers = leftSpec.reducers(rightSpec)
+        val rightReducers = rightSpec.reducers(leftSpec)
+
+        if (leftReducers.isDefined || rightReducers.isDefined) {
+          mergedPartValues = reduceCommonPartValues(mergedPartValues,
+            leftSpec.partitioning.expressions,
+            leftReducers)
+          mergedPartValues = reduceCommonPartValues(mergedPartValues,
+            rightSpec.partitioning.expressions,
+            rightReducers)
+          val rowOrdering = RowOrdering
+            .createNaturalAscendingOrdering(partitionExprs.map(_.dataType))
+          mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+        }
+
+        // Now we need to push-down the common partition information to the scan in each child
+        newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions,
+          leftReducers, applyPartialClustering, replicateLeftSide)
+        newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions,
+          rightReducers, applyPartialClustering, replicateRightSide)
      }
    }
 
@@ -527,25 +545,38 @@
      joinType == LeftAnti || joinType == LeftOuter
  }
 
-  // Populate the common partition values down to the scan nodes
-  private def populatePartitionValues(
+  // Populate the common partition information down to the scan nodes
+  private def populateCommonPartitionInfo(
      plan: SparkPlan,
      values: Seq[(InternalRow, Int)],
      joinKeyPositions: Option[Seq[Int]],
+      reducers: Option[Seq[Option[Reducer[Any]]]],
      applyPartialClustering: Boolean,
      replicatePartitions: Boolean): SparkPlan = plan match {
    case scan: BatchScanExec =>
      scan.copy(
        spjParams = scan.spjParams.copy(
          commonPartitionValues = Some(values),
          joinKeyPositions = joinKeyPositions,
+          reducers = reducers,
          applyPartialClustering = applyPartialClustering,
          replicatePartitions = replicatePartitions
        )
      )
    case node =>
-      node.mapChildren(child => populatePartitionValues(
-        child, values, joinKeyPositions, applyPartialClustering, replicatePartitions))
+      node.mapChildren(child => populateCommonPartitionInfo(
+        child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
+  }
+
+  private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)],
+      expressions: Seq[Expression],
+      reducers: Option[Seq[Option[Reducer[Any]]]]) = {
+    reducers match {
+      case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
+        KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
+      }.map { case (wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq
+      case _ => commonPartValues
+    }
  }
 
  /**
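The helper's behavior mirrors the grouping shown earlier, with the extra detail that per-partition split counts are summed when partition values merge (illustration only, continuing the hypothetical 16-to-8 bucket example with plain Ints standing in for partition rows):

val reducer = BucketReducer(largeNumBuckets = 16, smallNumBuckets = 8)

// (bucket id, number of input splits) for the common partition values.
val commonPartValues = Seq((3, 2), (11, 1), (0, 4))

val reduced = commonPartValues
  .groupBy { case (bucketId, _) => reducer.reduce(bucketId) }
  .map { case (bucketId, splits) => (bucketId, splits.map(_._2).sum) }
  .toSeq
// Seq((3, 3), (0, 4)): buckets 3 and 11 merge into bucket 3 with 2 + 1 splits.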
