Skip to content

Commit

Permalink
Second round review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
szehon-ho committed Mar 19, 2024
1 parent 23c580f commit 7a25afe
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@
/**
* A 'reducer' for output of user-defined functions.
*
* A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x),
* if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x.
* @see ReducibleFunction
*
* A user defined function f_source(x) is 'reducible' on another user_defined function f_target(x) if
* <ul>
* <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. </li>
* <li> More generally, there exists two reducer functions r1(x) and r2(x) such that
* r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
* </ul>
*
* @param <I> reducer input type
* @param <O> reducer output type
* @since 4.0.0
*/
@Evolving
public interface Reducer<I, O> {
O reduce(I arg1);
O reduce(I arg);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,34 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import scala.Option;

/**
* Base class for user-defined functions that can be 'reduced' on another function.
*
* A function f_source(x) is 'reducible' on another function f_target(x) if
* there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
*
* <ul>
* <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. </li>
* <li> More generally, there exists two reducer functions r1(x) and r2(x) such that
* r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
* </ul>
* <p>
* Examples:
* <ul>
* <li>Bucket functions
* <li>Bucket functions where one side has reducer
* <ul>
* <li>f_source(x) = bucket(4, x)</li>
* <li>f_target(x) = bucket(2, x)</li>
* <li>r(x) = x / 2</li>
* <li>r(x) = x % 2</li>
* </ul>
*
* <li>Bucket functions where both sides have reducer
* <ul>
* <li>f_source(x) = bucket(16, x)</li>
* <li>f_target(x) = bucket(12, x)</li>
* <li>r1(x) = x % 4</li>
* <li>r2(x) = x % 4</li>
* </ul>
*
* <li>Date functions
* <ul>
* <li>f_source(x) = days(x)</li>
Expand All @@ -49,24 +60,42 @@
public interface ReducibleFunction<I, O> {

/**
* If this function is 'reducible' on another function, return the {@link Reducer} function.
* This method is for bucket functions.
*
* If this bucket function is 'reducible' on another bucket function, return the {@link Reducer} function.
* <p>
* Example:
* Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x)
* <ul>
* <li>this_function = bucket(4, x)
* <li>other function = bucket(2, x)
* <li>thisFunction = bucket</li>
* <li>otherFunction = bucket</li>
* <li>thisNumBuckets = Int(4)</li>
* <li>otherNumBuckets = Int(2)</li>
* </ul>
* Invoke with arguments
*
* @param otherFunction the other bucket function
* @param thisNumBuckets number of buckets for this bucket function
* @param otherNumBuckets number of buckets for the other bucket function
* @return a reduction function if it is reducible, null if not
*/
default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction, int thisNumBuckets, int otherNumBuckets) {
return reducer(otherFunction);
}

/**
* This method is for all other functions.
*
* If this function is 'reducible' on another function, return the {@link Reducer} function.
* <p>
* Example of reducing f_source = days(x) on f_target = hours(x)
* <ul>
* <li>other = bucket</li>
* <li>this param = Int(4)</li>
* <li>other param = Int(2)</li>
* <li>thisFunction = days</li>
* <li>otherFunction = hours</li>
* </ul>
* @param other the other function
* @param thisParam param for this function
* @param otherParam param for the other function
* @return a reduction function if it is reducible, none if not
*
* @param otherFunction the other function
* @return a reduction function if it is reducible, null if not.
*/
Option<Reducer<I, O>> reducer(ReducibleFunction<?, ?> other, Option<?> thisParam,
Option<?> otherParam);
default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
return reducer(otherFunction, 0, 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ case class TransformExpression(
} else {
(function, other.function) match {
case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt)
val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt)
reducer.isDefined || otherReducer.isDefined
val reducer = f.reducer(o, numBucketsOpt.getOrElse(0), other.numBucketsOpt.getOrElse(0))
val otherReducer =
o.reducer(f, other.numBucketsOpt.getOrElse(0), numBucketsOpt.getOrElse(0))
reducer != null || otherReducer != null
case _ => false
}
}
Expand All @@ -90,7 +91,10 @@ case class TransformExpression(
def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
(function, other.function) match {
case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
e1.reducer(e2, numBucketsOpt, other.numBucketsOpt)
val reducer = e1.reducer(e2,
numBucketsOpt.getOrElse(0),
other.numBucketsOpt.getOrElse(0))
Option(reducer)
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,25 +851,23 @@ case class KeyGroupedShuffleSpec(
* A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
* 'reducible' on the corresponding partition expression function of the other shuffle spec.
* <p>
* If a value is returned, there must be one Option[[Reducer]] per partition expression.
* If a value is returned, there must be one [[Reducer]] per partition expression.
* A None value in the set indicates that the particular partition expression is not reducible
* on the corresponding expression on the other shuffle spec.
* <p>
* Returning none also indicates that none of the partition expressions can be reduced on the
* corresponding expression on the other shuffle spec.
*
* @param other other key-grouped shuffle spec
*/
def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
other match {
case otherSpec: KeyGroupedShuffleSpec =>
val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map {
case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
case (_, _) => None
}
def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
val results = partitioning.expressions.zip(other.partitioning.expressions).map {
case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
case (_, _) => None
}

// optimize to not return a value, if none of the partition expressions are reducible
if (results.forall(p => p.isEmpty)) None else Some(results)
case _ => None
}
// optimize to not return a value, if none of the partition expressions are reducible
if (results.forall(p => p.isEmpty)) None else Some(results)
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
Expand All @@ -883,8 +881,8 @@ case class KeyGroupedShuffleSpec(

object KeyGroupedShuffleSpec {
def reducePartitionValue(row: InternalRow,
expressions: Seq[Expression],
reducers: Seq[Option[Reducer[_, _]]]):
expressions: Seq[Expression],
reducers: Seq[Option[Reducer[_, _]]]):
InternalRowComparableWrapper = {
val partitionVals = row.toSeq(expressions.map(_.dataType))
val reducedRow = partitionVals.zip(reducers).map{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,168 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("SPARK-47094: Support compatible buckets with common divisor") {
val table1 = "tab1e1"
val table2 = "table2"

Seq(
((6, 4), (4, 6)),
((6, 6), (4, 4)),
((4, 4), (6, 6)),
((4, 6), (6, 4))).foreach {
case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) =>
catalog.clearTables()

val partition1 = Array(bucket(table1buckets1, "store_id"),
bucket(table1buckets2, "dept_id"))
val partition2 = Array(bucket(table2buckets1, "store_id"),
bucket(table2buckets2, "dept_id"))

Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) =>
createTable(tab, columns2, part)
val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
"(0, 0, 'aa'), " +
"(0, 0, 'ab'), " + // duplicate partition key
"(0, 1, 'ac'), " +
"(0, 2, 'ad'), " +
"(0, 3, 'ae'), " +
"(0, 4, 'af'), " +
"(0, 5, 'ag'), " +
"(1, 0, 'ah'), " +
"(1, 0, 'ai'), " + // duplicate partition key
"(1, 1, 'aj'), " +
"(1, 2, 'ak'), " +
"(1, 3, 'al'), " +
"(1, 4, 'am'), " +
"(1, 5, 'an'), " +
"(2, 0, 'ao'), " +
"(2, 0, 'ap'), " + // duplicate partition key
"(2, 1, 'aq'), " +
"(2, 2, 'ar'), " +
"(2, 3, 'as'), " +
"(2, 4, 'at'), " +
"(2, 5, 'au'), " +
"(3, 0, 'av'), " +
"(3, 0, 'aw'), " + // duplicate partition key
"(3, 1, 'ax'), " +
"(3, 2, 'ay'), " +
"(3, 3, 'az'), " +
"(3, 4, 'ba'), " +
"(3, 5, 'bb'), " +
"(4, 0, 'bc'), " +
"(4, 0, 'bd'), " + // duplicate partition key
"(4, 1, 'be'), " +
"(4, 2, 'bf'), " +
"(4, 3, 'bg'), " +
"(4, 4, 'bh'), " +
"(4, 5, 'bi'), " +
"(5, 0, 'bj'), " +
"(5, 0, 'bk'), " + // duplicate partition key
"(5, 1, 'bl'), " +
"(5, 2, 'bm'), " +
"(5, 3, 'bn'), " +
"(5, 4, 'bo'), " +
"(5, 5, 'bp')"

// additional unmatched partitions to test push down
val finalStr = if (tab == table1) {
insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
} else {
insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
}

sql(finalStr)
}

Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
allowJoinKeysSubsetOfPartitionKeys.toString,
SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
val df = sql(
s"""
|${selectWithMergeJoinHint("t1", "t2")}
|t1.store_id, t1.dept_id, t1.data, t2.data
|FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
|ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
|ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
|""".stripMargin)

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.isEmpty, "SPJ should be triggered")

val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
partitions.length)

def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
val expectedBuckets = gcd(table1buckets1, table2buckets1) *
gcd(table1buckets2, table2buckets2)
assert(scans == Seq(expectedBuckets, expectedBuckets))

checkAnswer(df, Seq(
Row(0, 0, "aa", "aa"),
Row(0, 0, "aa", "ab"),
Row(0, 0, "ab", "aa"),
Row(0, 0, "ab", "ab"),
Row(0, 1, "ac", "ac"),
Row(0, 2, "ad", "ad"),
Row(0, 3, "ae", "ae"),
Row(0, 4, "af", "af"),
Row(0, 5, "ag", "ag"),
Row(1, 0, "ah", "ah"),
Row(1, 0, "ah", "ai"),
Row(1, 0, "ai", "ah"),
Row(1, 0, "ai", "ai"),
Row(1, 1, "aj", "aj"),
Row(1, 2, "ak", "ak"),
Row(1, 3, "al", "al"),
Row(1, 4, "am", "am"),
Row(1, 5, "an", "an"),
Row(2, 0, "ao", "ao"),
Row(2, 0, "ao", "ap"),
Row(2, 0, "ap", "ao"),
Row(2, 0, "ap", "ap"),
Row(2, 1, "aq", "aq"),
Row(2, 2, "ar", "ar"),
Row(2, 3, "as", "as"),
Row(2, 4, "at", "at"),
Row(2, 5, "au", "au"),
Row(3, 0, "av", "av"),
Row(3, 0, "av", "aw"),
Row(3, 0, "aw", "av"),
Row(3, 0, "aw", "aw"),
Row(3, 1, "ax", "ax"),
Row(3, 2, "ay", "ay"),
Row(3, 3, "az", "az"),
Row(3, 4, "ba", "ba"),
Row(3, 5, "bb", "bb"),
Row(4, 0, "bc", "bc"),
Row(4, 0, "bc", "bd"),
Row(4, 0, "bd", "bc"),
Row(4, 0, "bd", "bd"),
Row(4, 1, "be", "be"),
Row(4, 2, "bf", "bf"),
Row(4, 3, "bg", "bg"),
Row(4, 4, "bh", "bh"),
Row(4, 5, "bi", "bi"),
Row(5, 0, "bj", "bj"),
Row(5, 0, "bj", "bk"),
Row(5, 0, "bk", "bj"),
Row(5, 0, "bk", "bk"),
Row(5, 1, "bl", "bl"),
Row(5, 2, "bm", "bm"),
Row(5, 3, "bn", "bn"),
Row(5, 4, "bo", "bo"),
Row(5, 5, "bp", "bp")
))
}
}
}
}

test("SPARK-47094: Support compatible buckets with less join keys than partition keys") {
val table1 = "tab1e1"
val table2 = "table2"
Expand Down
Loading

0 comments on commit 7a25afe

Please sign in to comment.