Commit 7a25afe

Second round review comments

1 parent 23c580f commit 7a25afe

6 files changed: +261 -51 lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java

Lines changed: 10 additions & 3 deletions
@@ -21,13 +21,20 @@
 /**
  * A 'reducer' for output of user-defined functions.
  *
- * A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x),
- * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x.
+ * @see ReducibleFunction
+ *
+ * A user defined function f_source(x) is 'reducible' on another user_defined function f_target(x) if
+ * <ul>
+ * <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. </li>
+ * <li> More generally, there exists two reducer functions r1(x) and r2(x) such that
+ * r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
+ * </ul>
+ *
  * @param <I> reducer input type
  * @param <O> reducer output type
  * @since 4.0.0
  */
 @Evolving
 public interface Reducer<I, O> {
-  O reduce(I arg1);
+  O reduce(I arg);
 }
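
The Reducer contract stays a single-method interface. As a point of reference, a minimal Scala sketch of a reducer matching the javadoc example, reducing bucket(4, x) output onto bucket(2, x) output, could look like the following; ModuloReducer is an illustrative name and not part of this commit.

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Illustrative sketch only: reduces a source bucket id to a target bucket id
// by taking it modulo the target bucket count, e.g. r(x) = x % 2 turns
// bucket(4, x) output into bucket(2, x) output.
class ModuloReducer(divisor: Int) extends Reducer[Int, Int] {
  override def reduce(arg: Int): Int = arg % divisor
}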

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java

Lines changed: 48 additions & 19 deletions
@@ -17,23 +17,34 @@
 package org.apache.spark.sql.connector.catalog.functions;
 
 import org.apache.spark.annotation.Evolving;
-import scala.Option;
 
 /**
  * Base class for user-defined functions that can be 'reduced' on another function.
  *
  * A function f_source(x) is 'reducible' on another function f_target(x) if
- * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
- *
+ * <ul>
+ * <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. </li>
+ * <li> More generally, there exists two reducer functions r1(x) and r2(x) such that
+ * r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
+ * </ul>
  * <p>
  * Examples:
  * <ul>
- * <li>Bucket functions
+ * <li>Bucket functions where one side has reducer
  * <ul>
  * <li>f_source(x) = bucket(4, x)</li>
 * <li>f_target(x) = bucket(2, x)</li>
- * <li>r(x) = x / 2</li>
+ * <li>r(x) = x % 2</li>
 * </ul>
+ *
+ * <li>Bucket functions where both sides have reducer
+ * <ul>
+ * <li>f_source(x) = bucket(16, x)</li>
+ * <li>f_target(x) = bucket(12, x)</li>
+ * <li>r1(x) = x % 4</li>
+ * <li>r2(x) = x % 4</li>
+ * </ul>
+ *
 * <li>Date functions
 * <ul>
 * <li>f_source(x) = days(x)</li>
@@ -49,24 +60,42 @@
 public interface ReducibleFunction<I, O> {
 
   /**
-   * If this function is 'reducible' on another function, return the {@link Reducer} function.
+   * This method is for bucket functions.
+   *
+   * If this bucket function is 'reducible' on another bucket function, return the {@link Reducer} function.
    * <p>
-   * Example:
+   * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x)
    * <ul>
-   * <li>this_function = bucket(4, x)
-   * <li>other function = bucket(2, x)
+   * <li>thisFunction = bucket</li>
+   * <li>otherFunction = bucket</li>
+   * <li>thisNumBuckets = Int(4)</li>
+   * <li>otherNumBuckets = Int(2)</li>
    * </ul>
-   * Invoke with arguments
+   *
+   * @param otherFunction the other bucket function
+   * @param thisNumBuckets number of buckets for this bucket function
+   * @param otherNumBuckets number of buckets for the other bucket function
+   * @return a reduction function if it is reducible, null if not
+   */
+  default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction, int thisNumBuckets, int otherNumBuckets) {
+    return reducer(otherFunction);
+  }
+
+  /**
+   * This method is for all other functions.
+   *
+   * If this function is 'reducible' on another function, return the {@link Reducer} function.
+   * <p>
+   * Example of reducing f_source = days(x) on f_target = hours(x)
    * <ul>
-   * <li>other = bucket</li>
-   * <li>this param = Int(4)</li>
-   * <li>other param = Int(2)</li>
+   * <li>thisFunction = days</li>
+   * <li>otherFunction = hours</li>
    * </ul>
-   * @param other the other function
-   * @param thisParam param for this function
-   * @param otherParam param for the other function
-   * @return a reduction function if it is reducible, none if not
+   *
+   * @param otherFunction the other function
+   * @return a reduction function if it is reducible, null if not.
    */
-  Option<Reducer<I, O>> reducer(ReducibleFunction<?, ?> other, Option<?> thisParam,
-      Option<?> otherParam);
+  default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
+    return reducer(otherFunction, 0, 0);
+  }
 }
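
To make the new two-sided contract concrete, here is a hedged Scala sketch of a bucket-style implementation: it treats itself as reducible on another bucket function whenever the two bucket counts share a common divisor greater than one, returning a modulo reducer or null otherwise. MyBucketFunction is a hypothetical class used only for illustration, not Spark's built-in bucket transform.

import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}

// Hypothetical bucket function, shown only to illustrate the reducer() contract.
class MyBucketFunction extends ReducibleFunction[Int, Int] {

  override def reducer(
      otherFunction: ReducibleFunction[_, _],
      thisNumBuckets: Int,
      otherNumBuckets: Int): Reducer[Int, Int] = otherFunction match {
    case _: MyBucketFunction if thisNumBuckets > 0 && otherNumBuckets > 0 =>
      // Per the javadoc examples: bucket(16, x) and bucket(12, x) share gcd = 4,
      // so each side reduces with r(x) = x % 4, while bucket(4, x) on
      // bucket(2, x) reduces with r(x) = x % 2.
      val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt
      if (gcd > 1 && gcd != thisNumBuckets) {
        // This side still needs coarsening to the common granularity.
        new Reducer[Int, Int] {
          override def reduce(arg: Int): Int = arg % gcd
        }
      } else {
        // No useful common divisor, or this side already matches it.
        null
      }
    case _ => null
  }
}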

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala

Lines changed: 8 additions & 4 deletions
@@ -70,9 +70,10 @@ case class TransformExpression(
     } else {
       (function, other.function) match {
         case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
-          val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt)
-          val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt)
-          reducer.isDefined || otherReducer.isDefined
+          val reducer = f.reducer(o, numBucketsOpt.getOrElse(0), other.numBucketsOpt.getOrElse(0))
+          val otherReducer =
+            o.reducer(f, other.numBucketsOpt.getOrElse(0), numBucketsOpt.getOrElse(0))
+          reducer != null || otherReducer != null
         case _ => false
       }
     }
@@ -90,7 +91,10 @@
   def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
     (function, other.function) match {
       case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
-        e1.reducer(e2, numBucketsOpt, other.numBucketsOpt)
+        val reducer = e1.reducer(e2,
+          numBucketsOpt.getOrElse(0),
+          other.numBucketsOpt.getOrElse(0))
+        Option(reducer)
       case _ => None
     }
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 12 additions & 14 deletions
@@ -851,25 +851,23 @@ case class KeyGroupedShuffleSpec(
    * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
    * 'reducible' on the corresponding partition expression function of the other shuffle spec.
    * <p>
-   * If a value is returned, there must be one Option[[Reducer]] per partition expression.
+   * If a value is returned, there must be one [[Reducer]] per partition expression.
    * A None value in the set indicates that the particular partition expression is not reducible
    * on the corresponding expression on the other shuffle spec.
    * <p>
    * Returning none also indicates that none of the partition expressions can be reduced on the
    * corresponding expression on the other shuffle spec.
+   *
+   * @param other other key-grouped shuffle spec
    */
-  def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
-    other match {
-      case otherSpec: KeyGroupedShuffleSpec =>
-        val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map {
-          case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
-          case (_, _) => None
-        }
+  def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
+    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
+      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
+      case (_, _) => None
+    }
 
-        // optimize to not return a value, if none of the partition expressions are reducible
-        if (results.forall(p => p.isEmpty)) None else Some(results)
-      case _ => None
-    }
+    // optimize to not return a value, if none of the partition expressions are reducible
+    if (results.forall(p => p.isEmpty)) None else Some(results)
   }
 
   override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
@@ -883,8 +881,8 @@
 
 object KeyGroupedShuffleSpec {
   def reducePartitionValue(row: InternalRow,
-                           expressions: Seq[Expression],
-                           reducers: Seq[Option[Reducer[_, _]]]):
+      expressions: Seq[Expression],
+      reducers: Seq[Option[Reducer[_, _]]]):
     InternalRowComparableWrapper = {
     val partitionVals = row.toSeq(expressions.map(_.dataType))
     val reducedRow = partitionVals.zip(reducers).map{
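
The scaladoc above pins down the shape that reducers now returns: at most one optional Reducer per partition expression, or None when nothing is reducible. What follows is a hedged, conceptual Scala sketch of how such a sequence of optional reducers is applied to the extracted partition values, in the spirit of reducePartitionValue; it is not the literal Spark implementation, whose body is truncated in the hunk above.

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Conceptual sketch: apply each optional reducer to its corresponding
// partition value, leaving values without a reducer unchanged.
def applyReducers(
    partitionVals: Seq[Any],
    reducers: Seq[Option[Reducer[_, _]]]): Seq[Any] = {
  partitionVals.zip(reducers).map {
    case (value, Some(reducer)) => reducer.asInstanceOf[Reducer[Any, Any]].reduce(value)
    case (value, None) => value
  }
}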

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 162 additions & 0 deletions
@@ -1475,6 +1475,168 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("SPARK-47094: Support compatible buckets with common divisor") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+
+    Seq(
+      ((6, 4), (4, 6)),
+      ((6, 6), (4, 4)),
+      ((4, 4), (6, 6)),
+      ((4, 6), (6, 4))).foreach {
+      case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) =>
+        catalog.clearTables()
+
+        val partition1 = Array(bucket(table1buckets1, "store_id"),
+          bucket(table1buckets2, "dept_id"))
+        val partition2 = Array(bucket(table2buckets1, "store_id"),
+          bucket(table2buckets2, "dept_id"))
+
+        Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) =>
+          createTable(tab, columns2, part)
+          val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
+            "(0, 0, 'aa'), " +
+            "(0, 0, 'ab'), " + // duplicate partition key
+            "(0, 1, 'ac'), " +
+            "(0, 2, 'ad'), " +
+            "(0, 3, 'ae'), " +
+            "(0, 4, 'af'), " +
+            "(0, 5, 'ag'), " +
+            "(1, 0, 'ah'), " +
+            "(1, 0, 'ai'), " + // duplicate partition key
+            "(1, 1, 'aj'), " +
+            "(1, 2, 'ak'), " +
+            "(1, 3, 'al'), " +
+            "(1, 4, 'am'), " +
+            "(1, 5, 'an'), " +
+            "(2, 0, 'ao'), " +
+            "(2, 0, 'ap'), " + // duplicate partition key
+            "(2, 1, 'aq'), " +
+            "(2, 2, 'ar'), " +
+            "(2, 3, 'as'), " +
+            "(2, 4, 'at'), " +
+            "(2, 5, 'au'), " +
+            "(3, 0, 'av'), " +
+            "(3, 0, 'aw'), " + // duplicate partition key
+            "(3, 1, 'ax'), " +
+            "(3, 2, 'ay'), " +
+            "(3, 3, 'az'), " +
+            "(3, 4, 'ba'), " +
+            "(3, 5, 'bb'), " +
+            "(4, 0, 'bc'), " +
+            "(4, 0, 'bd'), " + // duplicate partition key
+            "(4, 1, 'be'), " +
+            "(4, 2, 'bf'), " +
+            "(4, 3, 'bg'), " +
+            "(4, 4, 'bh'), " +
+            "(4, 5, 'bi'), " +
+            "(5, 0, 'bj'), " +
+            "(5, 0, 'bk'), " + // duplicate partition key
+            "(5, 1, 'bl'), " +
+            "(5, 2, 'bm'), " +
+            "(5, 3, 'bn'), " +
+            "(5, 4, 'bo'), " +
+            "(5, 5, 'bp')"
+
+          // additional unmatched partitions to test push down
+          val finalStr = if (tab == table1) {
+            insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
+          } else {
+            insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
+          }
+
+          sql(finalStr)
+        }
+
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+              allowJoinKeysSubsetOfPartitionKeys.toString,
+            SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+            val df = sql(
+              s"""
+                 |${selectWithMergeJoinHint("t1", "t2")}
+                 |t1.store_id, t1.dept_id, t1.data, t2.data
+                 |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+                 |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
+                 |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
+                 |""".stripMargin)
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            assert(shuffles.isEmpty, "SPJ should be triggered")
+
+            val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
+              partitions.length)
+
+            def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
+            val expectedBuckets = gcd(table1buckets1, table2buckets1) *
+              gcd(table1buckets2, table2buckets2)
+            assert(scans == Seq(expectedBuckets, expectedBuckets))
+
+            checkAnswer(df, Seq(
+              Row(0, 0, "aa", "aa"),
+              Row(0, 0, "aa", "ab"),
+              Row(0, 0, "ab", "aa"),
+              Row(0, 0, "ab", "ab"),
+              Row(0, 1, "ac", "ac"),
+              Row(0, 2, "ad", "ad"),
+              Row(0, 3, "ae", "ae"),
+              Row(0, 4, "af", "af"),
+              Row(0, 5, "ag", "ag"),
+              Row(1, 0, "ah", "ah"),
+              Row(1, 0, "ah", "ai"),
+              Row(1, 0, "ai", "ah"),
+              Row(1, 0, "ai", "ai"),
+              Row(1, 1, "aj", "aj"),
+              Row(1, 2, "ak", "ak"),
+              Row(1, 3, "al", "al"),
+              Row(1, 4, "am", "am"),
+              Row(1, 5, "an", "an"),
+              Row(2, 0, "ao", "ao"),
+              Row(2, 0, "ao", "ap"),
+              Row(2, 0, "ap", "ao"),
+              Row(2, 0, "ap", "ap"),
+              Row(2, 1, "aq", "aq"),
+              Row(2, 2, "ar", "ar"),
+              Row(2, 3, "as", "as"),
+              Row(2, 4, "at", "at"),
+              Row(2, 5, "au", "au"),
+              Row(3, 0, "av", "av"),
+              Row(3, 0, "av", "aw"),
+              Row(3, 0, "aw", "av"),
+              Row(3, 0, "aw", "aw"),
+              Row(3, 1, "ax", "ax"),
+              Row(3, 2, "ay", "ay"),
+              Row(3, 3, "az", "az"),
+              Row(3, 4, "ba", "ba"),
+              Row(3, 5, "bb", "bb"),
+              Row(4, 0, "bc", "bc"),
+              Row(4, 0, "bc", "bd"),
+              Row(4, 0, "bd", "bc"),
+              Row(4, 0, "bd", "bd"),
+              Row(4, 1, "be", "be"),
+              Row(4, 2, "bf", "bf"),
+              Row(4, 3, "bg", "bg"),
+              Row(4, 4, "bh", "bh"),
+              Row(4, 5, "bi", "bi"),
+              Row(5, 0, "bj", "bj"),
+              Row(5, 0, "bj", "bk"),
+              Row(5, 0, "bk", "bj"),
+              Row(5, 0, "bk", "bk"),
+              Row(5, 1, "bl", "bl"),
+              Row(5, 2, "bm", "bm"),
+              Row(5, 3, "bn", "bn"),
+              Row(5, 4, "bo", "bo"),
+              Row(5, 5, "bp", "bp")
+            ))
+          }
+        }
+    }
+  }
+
   test("SPARK-47094: Support compatible buckets with less join keys than partition keys") {
     val table1 = "tab1e1"
     val table2 = "table2"
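
For reference, the expectedBuckets assertion in the new test follows from the per-column greatest common divisors of the two tables' bucket counts; every pairing exercised above reduces both scans to the same partition count.

// Worked check of the expectedBuckets arithmetic used in the test above.
def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt

// ((6, 4), (4, 6)): gcd(6, 4) * gcd(4, 6) = 2 * 2 = 4 partitions per scan
// ((6, 6), (4, 4)): gcd(6, 4) * gcd(6, 4) = 2 * 2 = 4
// ((4, 4), (6, 6)): gcd(4, 6) * gcd(4, 6) = 2 * 2 = 4
// ((4, 6), (6, 4)): gcd(4, 6) * gcd(6, 4) = 2 * 2 = 4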
