
Commit e768eaa

linzebing authored and cloud-fan committed
[SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer)
### What changes were proposed in this pull request?

This PR adds code-gen support for broadcast nested loop join with left outer (build right) and right outer (build left) join types. Reference: `BroadcastNestedLoopJoinExec.codegenInner()` and `BroadcastNestedLoopJoinExec.outerJoin()`.

### Why are the changes needed?

Improve query CPU performance. Tested with a simple query:

```scala
val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left outer broadcast nested loop join", N) {
  val df = spark.range(N).selectExpr(s"id as k1").join(
    dim, col("k1") + 1 <= col("k2"), "left_outer")
  assert(df.queryExecution.sparkPlan.find(
    _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined)
  df.noop()
}
```

Seeing a 2x run time improvement:

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
left outer broadcast nested loop join:                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------------------------------
left outer broadcast nested loop join wholestage off            3024           3698         953          6.9         144.2       1.0X
left outer broadcast nested loop join wholestage on             1512           1659         172         13.9          72.1       2.0X
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Changed existing unit tests in `OuterJoinSuite` to cover the codegen paths. Added a unit test in `WholeStageCodegenSuite.scala` to make sure code-gen for broadcast nested loop join takes effect, and to test the multiple-join case as well.

Example query:

```scala
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_outer").explain("codegen")
```

Example generated code (`bnlj_doConsume_0` method):

```java
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, LeftOuter, ((k1#2L + 1) <= k2#6L)
:- *(2) Project [id#0L AS k1#2L]
:  +- *(2) Range (0, 4, step=1, splits=16)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
   +- *(1) Project [id#4L AS k2#6L]
      +- *(1) Range (0, 3, step=1, splits=16)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_foundMatch_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */       boolean bnlj_shouldOutputRow_0 = false;
/* 041 */
/* 042 */       boolean bnlj_isNull_2 = true;
/* 043 */       long bnlj_value_2 = -1L;
/* 044 */       if (bnlj_buildRow_0 != null) {
/* 045 */         long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 046 */         bnlj_isNull_2 = false;
/* 047 */         bnlj_value_2 = bnlj_value_1;
/* 048 */       }
/* 049 */
/* 050 */       long bnlj_value_4 = -1L;
/* 051 */
/* 052 */       bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 053 */
/* 054 */       boolean bnlj_value_3 = false;
/* 055 */       bnlj_value_3 = bnlj_value_4 <= bnlj_value_2;
/* 056 */       if (!(false || !bnlj_value_3))
/* 057 */       {
/* 058 */         bnlj_shouldOutputRow_0 = true;
/* 059 */         bnlj_foundMatch_0 = true;
/* 060 */       }
/* 061 */       if (bnlj_arrayIndex_0 == bnlj_buildRowArray_0.length - 1 && !bnlj_foundMatch_0) {
/* 062 */         bnlj_buildRow_0 = null;
/* 063 */         bnlj_shouldOutputRow_0 = true;
/* 064 */       }
/* 065 */       if (bnlj_shouldOutputRow_0) {
/* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 067 */
/* 068 */         boolean bnlj_isNull_9 = true;
/* 069 */         long bnlj_value_9 = -1L;
/* 070 */         if (bnlj_buildRow_0 != null) {
/* 071 */           long bnlj_value_8 = bnlj_buildRow_0.getLong(0);
/* 072 */           bnlj_isNull_9 = false;
/* 073 */           bnlj_value_9 = bnlj_value_8;
/* 074 */         }
/* 075 */         range_mutableStateArray_0[3].reset();
/* 076 */
/* 077 */         range_mutableStateArray_0[3].zeroOutNullBytes();
/* 078 */
/* 079 */         range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 080 */
/* 081 */         if (bnlj_isNull_9) {
/* 082 */           range_mutableStateArray_0[3].setNullAt(1);
/* 083 */         } else {
/* 084 */           range_mutableStateArray_0[3].write(1, bnlj_value_9);
/* 085 */         }
/* 086 */         append((range_mutableStateArray_0[3].getRow()).copy());
/* 087 */
/* 088 */       }
/* 089 */     }
/* 090 */
/* 091 */   }
/* 092 */
/* 093 */   private void initRange(int idx) {
/* 094 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 095 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(16L);
/* 096 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 097 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 098 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 099 */     long partitionEnd;
/* 100 */
/* 101 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 102 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 103 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 104 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 105 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 106 */     } else {
/* 107 */       range_nextIndex_0 = st.longValue();
/* 108 */     }
/* 109 */     range_batchEnd_0 = range_nextIndex_0;
/* 110 */
/* 111 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 112 */     .multiply(step).add(start);
/* 113 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 114 */       partitionEnd = Long.MAX_VALUE;
/* 115 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 116 */       partitionEnd = Long.MIN_VALUE;
/* 117 */     } else {
/* 118 */       partitionEnd = end.longValue();
/* 119 */     }
/* 120 */
/* 121 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 122 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 123 */     range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 124 */     if (range_numElementsTodo_0 < 0) {
/* 125 */       range_numElementsTodo_0 = 0;
/* 126 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 127 */       range_numElementsTodo_0++;
/* 128 */     }
/* 129 */   }
/* 130 */
/* 131 */   protected void processNext() throws java.io.IOException {
/* 132 */     // initialize Range
/* 133 */     if (!range_initRange_0) {
/* 134 */       range_initRange_0 = true;
/* 135 */       initRange(partitionIndex);
/* 136 */     }
/* 137 */
/* 138 */     while (true) {
/* 139 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 140 */         long range_nextBatchTodo_0;
/* 141 */         if (range_numElementsTodo_0 > 1000L) {
/* 142 */           range_nextBatchTodo_0 = 1000L;
/* 143 */           range_numElementsTodo_0 -= 1000L;
/* 144 */         } else {
/* 145 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 146 */           range_numElementsTodo_0 = 0;
/* 147 */           if (range_nextBatchTodo_0 == 0) break;
/* 148 */         }
/* 149 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 150 */       }
/* 151 */
/* 152 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 153 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 154 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 155 */
/* 156 */         // common sub-expressions
/* 157 */
/* 158 */         bnlj_doConsume_0(range_value_0);
/* 159 */
/* 160 */         if (shouldStop()) {
/* 161 */           range_nextIndex_0 = range_value_0 + 1L;
/* 162 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 163 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 164 */           return;
/* 165 */         }
/* 166 */
/* 167 */       }
/* 168 */       range_nextIndex_0 = range_batchEnd_0;
/* 169 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 170 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 171 */       range_taskContext_0.killTaskIfInterrupted();
/* 172 */     }
/* 173 */   }
/* 174 */
/* 175 */ }
```

Closes #31931 from linzebing/code-left-right-outer.

Authored-by: linzebing <linzebing1995@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent: 39542bb

4 files changed: +98 −6 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala

Lines changed: 46 additions & 1 deletion

```diff
@@ -396,7 +396,8 @@ case class BroadcastNestedLoopJoinExec(
   }
 
   override def supportCodegen: Boolean = (joinType, buildSide) match {
-    case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true
+    case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) |
+      (LeftSemi | LeftAnti, BuildRight) => true
     case _ => false
   }
 
@@ -413,6 +414,7 @@ case class BroadcastNestedLoopJoinExec(
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     (joinType, buildSide) match {
       case (_: InnerLike, _) => codegenInner(ctx, input)
+      case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input)
       case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true)
       case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false)
       case _ =>
@@ -458,6 +460,49 @@ case class BroadcastNestedLoopJoinExec(
     """.stripMargin
   }
 
+  private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
+    val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast)
+    val buildVars = genBuildSideVars(ctx, buildRow, broadcast)
+
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
+    val arrayIndex = ctx.freshName("arrayIndex")
+    val shouldOutputRow = ctx.freshName("shouldOutputRow")
+    val foundMatch = ctx.freshName("foundMatch")
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    if (buildRowArray.isEmpty) {
+      s"""
+         |UnsafeRow $buildRow = null;
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
+       """.stripMargin
+    } else {
+      s"""
+         |boolean $foundMatch = false;
+         |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
+         |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
+         |  boolean $shouldOutputRow = false;
+         |  $checkCondition {
+         |    $shouldOutputRow = true;
+         |    $foundMatch = true;
+         |  }
+         |  if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) {
+         |    $buildRow = null;
+         |    $shouldOutputRow = true;
+         |  }
+         |  if ($shouldOutputRow) {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
+         |}
+       """.stripMargin
+    }
+  }
+
   private def codegenLeftExistence(
       ctx: CodegenContext,
       input: Seq[ExprCode],
```
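For readers tracing the template above against the generated `bnlj_doConsume_0` in the commit message, the following is a minimal plain-Scala sketch of the per-streamed-row semantics the emitted loop encodes. The generic row type and the `condition`/`emit` callbacks are hypothetical stand-ins for illustration, not Spark APIs:

```scala
// Sketch only: the outer-join control flow that codegenOuter emits for one
// streamed row. `condition` and `emit` are illustrative stand-ins.
def consumeStreamedRow[R](
    streamedRow: R,
    buildRows: Array[R],
    condition: (R, R) => Boolean,
    emit: (R, Option[R]) => Unit): Unit = {
  if (buildRows.isEmpty) {
    // Mirrors the buildRowArray.isEmpty branch: always emit a null-padded row.
    emit(streamedRow, None)
  } else {
    var foundMatch = false
    for (i <- buildRows.indices) {
      if (condition(streamedRow, buildRows(i))) {
        foundMatch = true
        emit(streamedRow, Some(buildRows(i))) // matched pair
      }
      // After scanning the last build row, pad with nulls if nothing matched.
      if (i == buildRows.length - 1 && !foundMatch) {
        emit(streamedRow, None)
      }
    }
  }
}
```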

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 47 additions & 0 deletions

```diff
@@ -211,6 +211,53 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     }
   }
 
+  test("Left/Right outer BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
+    val df1 = spark.range(4).select($"id".as("k1"))
+    val df2 = spark.range(3).select($"id".as("k2"))
+    val df3 = spark.range(2).select($"id".as("k3"))
+    val df4 = spark.range(0).select($"id".as("k4"))
+
+    Seq(true, false).foreach { codegenEnabled =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
+        // test left outer join
+        val leftOuterJoinDF = df1.join(df2, $"k1" > $"k2", "left_outer")
+        var hasJoinInCodegen = leftOuterJoinDF.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
+        }.size === 1
+        assert(hasJoinInCodegen == codegenEnabled)
+        checkAnswer(leftOuterJoinDF,
+          Seq(Row(0, null), Row(1, 0), Row(2, 0), Row(2, 1), Row(3, 0), Row(3, 1), Row(3, 2)))
+
+        // test right outer join
+        val rightOuterJoinDF = df1.join(df2, $"k1" < $"k2", "right_outer")
+        hasJoinInCodegen = rightOuterJoinDF.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
+        }.size === 1
+        assert(hasJoinInCodegen == codegenEnabled)
+        checkAnswer(rightOuterJoinDF, Seq(Row(null, 0), Row(0, 1), Row(0, 2), Row(1, 2)))
+
+        // test a combination of left outer and right outer joins
+        val twoJoinsDF = df1.join(df2, $"k1" > $"k2" + 1, "right_outer")
+          .join(df3, $"k1" <= $"k3", "left_outer")
+        hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(BroadcastNestedLoopJoinExec(
+            _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true
+        }.size === 1
+        assert(hasJoinInCodegen == codegenEnabled)
+        checkAnswer(twoJoinsDF,
+          Seq(Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null)))
+
+        // test build side is empty
+        val buildSideIsEmptyDF = df3.join(df4, $"k3" > $"k4", "left_outer")
+        hasJoinInCodegen = buildSideIsEmptyDF.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
+        }.size === 1
+        assert(hasJoinInCodegen == codegenEnabled)
+        checkAnswer(buildSideIsEmptyDF, Seq(Row(0, null), Row(1, null)))
+      }
+    }
+  }
+
   test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
     val df1 = spark.range(4).select($"id".as("k1"))
     val df2 = spark.range(3).select($"id".as("k2"))
```
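To see interactively what these assertions verify, one can inspect the physical plan from a spark-shell session. A small sketch, assuming a running `SparkSession` named `spark`; operators prefixed with `*(n)` in the plan output run inside whole-stage-codegen stage `n`:

```scala
// Sketch: checking whether the broadcast nested loop join runs inside a
// whole-stage-codegen stage. Assumes a SparkSession named `spark`.
import org.apache.spark.sql.functions.col

val df1 = spark.range(4).select(col("id").as("k1"))
val df2 = spark.range(3).select(col("id").as("k2"))
val joined = df1.join(df2, col("k1") > col("k2"), "left_outer")

// With spark.sql.codegen.wholeStage=true the join shows up as
// "*(n) BroadcastNestedLoopJoin ..." in the printed plan.
joined.explain()
```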

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -149,7 +149,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
     }
   }
 
-  test(s"$testName using BroadcastNestedLoopJoin build left") {
+  testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build left") { _ =>
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)),
@@ -158,7 +158,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
     }
   }
 
-  test(s"$testName using BroadcastNestedLoopJoin build right") {
+  testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ =>
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)),
```
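`testWithWholeStageCodegenOnAndOff` is a helper in Spark's SQL test utilities that registers the same test body twice, once per codegen setting, so each outer-join case now exercises both the interpreted and the generated path. A simplified sketch of its shape (an approximation, not the exact source; the real helper lives in `SQLTestUtils`):

```scala
// Simplified sketch of the helper's shape; see SQLTestUtils in the Spark
// source tree for the actual definition.
def testWithWholeStageCodegenOnAndOff(testName: String)(f: String => Unit): Unit = {
  Seq("false", "true").foreach { enabled =>
    test(s"$testName (whole-stage-codegen ${if (enabled == "true") "on" else "off"})") {
      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) {
        f(enabled)
      }
    }
  }
}
```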

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 3 additions & 3 deletions

```diff
@@ -452,11 +452,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
       "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a"
     val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " +
       "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a"
-    Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true))
-      .foreach { case (query, enableWholeStage) =>
+    Seq((leftQuery, 0L, false), (rightQuery, 0L, false), (leftQuery, 1L, true),
+      (rightQuery, 1L, true)).foreach { case (query, nodeId, enableWholeStage) =>
       val df = spark.sql(query)
       testSparkPlanMetrics(df, 2, Map(
-        0L -> (("BroadcastNestedLoopJoin", Map(
+        nodeId -> (("BroadcastNestedLoopJoin", Map(
           "number of output rows" -> 12L)))),
         enableWholeStage
       )
```
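A note on why the expected node id moves from `0L` to `1L` when `enableWholeStage` is true: the join gets wrapped in a `WholeStageCodegenExec` node, which takes the first id in the numbered plan and shifts the join down by one. A quick way to see the wrapping (a sketch; `leftQuery` refers to the query defined in the test above):

```scala
// Sketch: print the operator tree to see WholeStageCodegenExec wrapping the
// join when codegen is enabled. Assumes `spark` and `leftQuery` from the test.
val df = spark.sql(leftQuery)
df.queryExecution.executedPlan.foreach(p => println(p.nodeName))
```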
