diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/BroadcastJoinOuterJoinStreamSide.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/BroadcastJoinOuterJoinStreamSide.scala index 5aab2f3b6d358..c6d6a9adf72e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/BroadcastJoinOuterJoinStreamSide.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/BroadcastJoinOuterJoinStreamSide.scala @@ -36,7 +36,8 @@ object BroadcastJoinOuterJoinStreamSide extends Rule[LogicalPlan] with JoinSelec LeftOuter | LeftSemi | LeftAnti, _, _) => j case j @ ExtractEquiJoinKeys(LeftOuter | LeftSemi | LeftAnti, - leftKeys, _, None, left, right, hint) if leftKeys.nonEmpty && muchSmaller(left, right) && + leftKeys, _, None, left, right, hint) + if leftKeys.nonEmpty && muchSmaller(left, right, conf) && !(hintToBroadcastRight(hint) || canBroadcastBySize(right, conf)) && (hintToBroadcastLeft(hint) || canBroadcastBySize(left, conf)) => logInfo("BroadcastJoinOuterJoinStreamSide detected.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 70c60f5124551..60e5569480419 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least @@ -270,28 +271,18 @@ trait JoinSelectionHelper { val buildLeft = if (hintOnly) { hintToShuffleHashJoinLeft(hint) } else { - if (hintToPreferShuffleHashJoinLeft(hint)) { - true - } else { - if (!conf.preferSortMergeJoin) { - canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right) - } else { - false - } - } + hintToPreferShuffleHashJoinLeft(hint) || + (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(left, conf) && + muchSmaller(left, right)) || + forceApplyShuffledHashJoin(conf) } val buildRight = if (hintOnly) { hintToShuffleHashJoinRight(hint) } else { - if (hintToPreferShuffleHashJoinRight(hint)) { - true - } else { - if (!conf.preferSortMergeJoin) { - canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) - } else { - false - } - } + hintToPreferShuffleHashJoinRight(hint) || + (!conf.preferSortMergeJoin && canBuildLocalHashMapBySize(right, conf) && + muchSmaller(right, left)) || + forceApplyShuffledHashJoin(conf) } getBuildSide( canBuildShuffledHashJoinLeft(joinType) && buildLeft, @@ -435,8 +426,8 @@ trait JoinSelectionHelper { * that is much smaller than other one. Since we does not have the statistic for number of rows, * use the size of bytes here as estimation. */ - def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes + def muchSmaller(a: LogicalPlan, b: LogicalPlan, conf: SQLConf): Boolean = { + a.stats.sizeInBytes * conf.getConf(SQLConf.SHUFFLE_HASH_JOIN_FACTOR) <= b.stats.sizeInBytes } def canBroadcastTokenTree(left: LogicalPlan, @@ -460,5 +451,14 @@ trait JoinSelectionHelper { right.stats.sizeInBytes <= conf.containsJoinThreshold && !hintToNotBroadcastRight(hint) } + + /** + * Returns whether a shuffled hash join should be force applied. + * The config key is hard-coded because it's testing only and should not be exposed. + */ + private def forceApplyShuffledHashJoin(conf: SQLConf): Boolean = { + Utils.isTesting && + conf.getConfString("spark.sql.join.forceApplyShuffledHashJoin", "false") == "true" + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql index cc07b00cc3670..1a8ecbbe4178a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql @@ -13,7 +13,7 @@ --CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=10485760 --CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true ---CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false +--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.forceApplyShuffledHashJoin=true --CONFIG_DIM2 spark.sql.codegen.wholeStage=true --CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ecd3e3c44c84d..876ef38f3b8f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1603,4 +1603,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } } } + + test("SPARK-35984: Config to force applying shuffled hash join") { + val sql = "SELECT * FROM testData JOIN testData2 ON key = a" + assertJoin(sql, classOf[SortMergeJoinExec]) + withSQLConf("spark.sql.join.forceApplyShuffledHashJoin" -> "true") { + assertJoin(sql, classOf[ShuffledHashJoinExec]) + } + } }