
Commit d4b423c

refactor

1 parent 29eea67, commit d4b423c

6 files changed: +106 -0 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 9 additions & 0 deletions

@@ -2421,6 +2421,15 @@ object SQLConf {
     .doubleConf
     .createWithDefault(0.9)
 
+  val SWITCH_SORT_MERGE_JOIN_SIDES_ENABLED =
+    buildConf("spark.sql.switchSortMergeJoinSides.enabled")
+      .internal()
+      .doc("If true, switch the sides of an inner-like sort merge join according to the " +
+        "plan size and child unique keys.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)
+
   private def isValidTimezone(zone: String): Boolean = {
     Try { DateTimeUtils.getZoneId(zone) }.isSuccess
   }
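
The flag is marked internal but, like any SQL config, it can still be toggled at runtime. A minimal sketch (not part of this commit; it assumes an existing SparkSession named spark):

// Minimal sketch, not from this commit: toggling the new rule per session.
spark.conf.set("spark.sql.switchSortMergeJoinSides.enabled", "false") // disable side switching
// ... run queries whose sort merge joins should keep their original sides ...
spark.conf.set("spark.sql.switchSortMergeJoinSides.enabled", "true")  // restore the default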

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, Insert
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
 import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.joins.SwitchJoinSides
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
 import org.apache.spark.sql.internal.SQLConf
@@ -405,6 +406,7 @@ object QueryExecution {
       // as the original plan is hidden behind `AdaptiveSparkPlanExec`.
       adaptiveExecutionRule.toSeq ++
       Seq(
+        SwitchJoinSides,
         CoalesceBucketsInJoin,
         PlanDynamicPruningFilters(sparkSession),
         PlanSubqueries(sparkSession),

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
 import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
 import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.execution.joins.SwitchJoinSides
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -116,6 +117,7 @@ case class AdaptiveSparkPlanExec(
     val ensureRequirements =
       EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
     Seq(
+      SwitchJoinSides,
       RemoveRedundantProjects,
       ensureRequirements,
       ReplaceHashWithSortAgg,
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SwitchJoinSides.scala

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.ExpressionSet
+import org.apache.spark.sql.catalyst.plans.InnerLike
+import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Switch the join sides if the join satisfies all of:
+ * - it is an inner-like join
+ * - its physical plan is SortMergeJoinExec
+ * - its streamed side is smaller than its buffered side
+ * - its streamed side is unique for the join keys
+ */
+object SwitchJoinSides extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (!conf.getConf(SQLConf.SWITCH_SORT_MERGE_JOIN_SIDES_ENABLED)) {
+      return plan
+    }
+
+    plan transformUp {
+      case j @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, hint)
+          if j.logicalLink.isDefined =>
+        j.logicalLink.get match {
+          case Join(logicalLeft, logicalRight, _: InnerLike, _, _)
+              if logicalLeft.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys))) &&
+                logicalLeft.stats.sizeInBytes * 3 < logicalRight.stats.sizeInBytes =>
+            ProjectExec(
+              j.output,
+              SortMergeJoinExec(rightKeys, leftKeys, joinType, condition, right, left, hint)
+            )
+
+          case _ => j
+        }
+    }
+  }
+}
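
For context, here is a hedged illustration (not from this commit; the DataFrames and sizes are hypothetical) of the query shape the rule targets: the aggregated side is unique on its join key and much smaller than the other side, so the rule should move it to the buffered (right) side of the sort merge join.

// Illustrative sketch only, assuming a SparkSession named spark and that broadcast
// joins are not chosen first. The aggregated side is unique on c1 and small, so
// SwitchJoinSides is expected to place it on the buffered side of the SMJ.
val small = spark.range(10).selectExpr("id as c1").groupBy("c1").count()
val big = spark.range(1000000L).selectExpr("id % 100 as c2")
val joined = small.join(big, small("c1") === big("c2"), "inner")
joined.explain() // check which side of SortMergeJoin holds the HashAggregate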

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 19 additions & 0 deletions

@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrd
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
@@ -1440,4 +1441,22 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       }
     }
   }
+
+  test("SPARK-38887: Support switch inner join side for sort merge join") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(2).selectExpr("id as c1")
+      val df2 = spark.range(100).selectExpr("id as c2")
+      val plan1 = df1.groupBy($"c1").agg($"c1").join(df2, $"c1" === $"c2", "inner")
+        .queryExecution.executedPlan
+      val smj1 = find(plan1)(_.isInstanceOf[SortMergeJoinExec]).get.asInstanceOf[SortMergeJoinExec]
+      assert(!smj1.left.exists(_.isInstanceOf[HashAggregateExec]))
+      assert(smj1.right.exists(_.isInstanceOf[HashAggregateExec]))
+
+      val plan2 = df2.groupBy($"c2").agg($"c2").join(df1, $"c1" === $"c2", "inner")
+        .queryExecution.executedPlan
+      val smj2 = find(plan2)(_.isInstanceOf[SortMergeJoinExec]).get.asInstanceOf[SortMergeJoinExec]
+      assert(smj2.left.exists(_.isInstanceOf[HashAggregateExec]))
+      assert(!smj2.right.exists(_.isInstanceOf[HashAggregateExec]))
+    }
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala

Lines changed: 18 additions & 0 deletions

@@ -149,6 +149,22 @@ object JoinBenchmark extends SqlBasedBenchmark {
     }
   }
 
+  def sortMergeJoinWithBufferedSideDuplicates(switch: Boolean): Unit = {
+    val N1 = 2 << 20
+    val N2 = 2 << 24
+    withSQLConf(SQLConf.SWITCH_SORT_MERGE_JOIN_SIDES_ENABLED.key -> switch.toString) {
+      codegenBenchmark(s"sort merge join with buffered side duplicates, switched: $switch,", N2) {
+        val df1 = spark.range(N1).distinct()
+          .selectExpr(s"id as k1")
+        val df2 = spark.range(N2)
+          .selectExpr(s"id % 1000 as k2")
+        val df = df1.join(df2, col("k1") === col("k2"))
+        assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec]))
+        df.noop()
+      }
+    }
+  }
+
   def shuffleHashJoin(): Unit = {
     val N: Long = 4 << 20
     withSQLConf(
@@ -188,6 +204,8 @@
     broadcastHashJoinSemiJoinLongKey()
     sortMergeJoin()
     sortMergeJoinWithDuplicates()
+    sortMergeJoinWithBufferedSideDuplicates(true)
+    sortMergeJoinWithBufferedSideDuplicates(false)
     shuffleHashJoin()
     broadcastNestedLoopJoin()
   }
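
The new case is registered in runBenchmarkSuite above, so it runs with the rest of JoinBenchmark. The usual Spark benchmark invocation should work; the command below is shown for illustration and is not part of this commit:

build/sbt "sql/Test/runMain org.apache.spark.sql.execution.benchmark.JoinBenchmark"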
