Commit e768eaa
[SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer)
### What changes were proposed in this pull request?
This PR is to add code-gen support for left outer (build right) and right outer (build left). Reference: `BroadcastNestedLoopJoinExec.codegenInner()` and `BroadcastNestedLoopJoinExec.outerJoin()`
### Why are the changes needed?
Improve query CPU performance.
Tested with a simple query:
```scala
val N = 20 << 20
val M = 1 << 4
val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left outer broadcast nested loop join", N) {
val df = spark.range(N).selectExpr(s"id as k1").join(
dim, col("k1") + 1 <= col("k2"), "left_outer")
assert(df.queryExecution.sparkPlan.find(
_.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined)
df.noop()
}
```
Seeing 2x run time improvement:
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU 2.40GHz
left outer broadcast nested loop join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------------------
left outer broadcast nested loop join wholestage off 3024 3698 953 6.9 144.2 1.0X
left outer broadcast nested loop join wholestage on 1512 1659 172 13.9 72.1 2.0X
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Changed existing unit tests in `OuterJoinSuite` to cover codegen use cases.
Added unit test in WholeStageCodegenSuite.scala to make sure code-gen for broadcast nested loop join is taking effect, and test for multiple join case as well.
Example query:
```scala
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_outer").explain("codegen")
```
Example generated code (`bnlj_doConsume_0` method):
```java
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, LeftOuter, ((k1#2L + 1) <= k2#6L)
:- *(2) Project [id#0L AS k1#2L]
: +- *(2) Range (0, 4, step=1, splits=16)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
+- *(1) Project [id#4L AS k2#6L]
+- *(1) Range (0, 3, step=1, splits=16)
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean range_initRange_0;
/* 010 */ private long range_nextIndex_0;
/* 011 */ private TaskContext range_taskContext_0;
/* 012 */ private InputMetrics range_inputMetrics_0;
/* 013 */ private long range_batchEnd_0;
/* 014 */ private long range_numElementsTodo_0;
/* 015 */ private InternalRow[] bnlj_buildRowArray_0;
/* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */ public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */ this.references = references;
/* 020 */ }
/* 021 */
/* 022 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */ partitionIndex = index;
/* 024 */ this.inputs = inputs;
/* 025 */
/* 026 */ range_taskContext_0 = TaskContext.get();
/* 027 */ range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */ range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */ range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */ range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */ bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */ range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */ }
/* 035 */
/* 036 */ private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */ boolean bnlj_foundMatch_0 = false;
/* 038 */ for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */ UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */ boolean bnlj_shouldOutputRow_0 = false;
/* 041 */
/* 042 */ boolean bnlj_isNull_2 = true;
/* 043 */ long bnlj_value_2 = -1L;
/* 044 */ if (bnlj_buildRow_0 != null) {
/* 045 */ long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 046 */ bnlj_isNull_2 = false;
/* 047 */ bnlj_value_2 = bnlj_value_1;
/* 048 */ }
/* 049 */
/* 050 */ long bnlj_value_4 = -1L;
/* 051 */
/* 052 */ bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 053 */
/* 054 */ boolean bnlj_value_3 = false;
/* 055 */ bnlj_value_3 = bnlj_value_4 <= bnlj_value_2;
/* 056 */ if (!(false || !bnlj_value_3))
/* 057 */ {
/* 058 */ bnlj_shouldOutputRow_0 = true;
/* 059 */ bnlj_foundMatch_0 = true;
/* 060 */ }
/* 061 */ if (bnlj_arrayIndex_0 == bnlj_buildRowArray_0.length - 1 && !bnlj_foundMatch_0) {
/* 062 */ bnlj_buildRow_0 = null;
/* 063 */ bnlj_shouldOutputRow_0 = true;
/* 064 */ }
/* 065 */ if (bnlj_shouldOutputRow_0) {
/* 066 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 067 */
/* 068 */ boolean bnlj_isNull_9 = true;
/* 069 */ long bnlj_value_9 = -1L;
/* 070 */ if (bnlj_buildRow_0 != null) {
/* 071 */ long bnlj_value_8 = bnlj_buildRow_0.getLong(0);
/* 072 */ bnlj_isNull_9 = false;
/* 073 */ bnlj_value_9 = bnlj_value_8;
/* 074 */ }
/* 075 */ range_mutableStateArray_0[3].reset();
/* 076 */
/* 077 */ range_mutableStateArray_0[3].zeroOutNullBytes();
/* 078 */
/* 079 */ range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 080 */
/* 081 */ if (bnlj_isNull_9) {
/* 082 */ range_mutableStateArray_0[3].setNullAt(1);
/* 083 */ } else {
/* 084 */ range_mutableStateArray_0[3].write(1, bnlj_value_9);
/* 085 */ }
/* 086 */ append((range_mutableStateArray_0[3].getRow()).copy());
/* 087 */
/* 088 */ }
/* 089 */ }
/* 090 */
/* 091 */ }
/* 092 */
/* 093 */ private void initRange(int idx) {
/* 094 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 095 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(16L);
/* 096 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 097 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 098 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 099 */ long partitionEnd;
/* 100 */
/* 101 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 102 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 103 */ range_nextIndex_0 = Long.MAX_VALUE;
/* 104 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 105 */ range_nextIndex_0 = Long.MIN_VALUE;
/* 106 */ } else {
/* 107 */ range_nextIndex_0 = st.longValue();
/* 108 */ }
/* 109 */ range_batchEnd_0 = range_nextIndex_0;
/* 110 */
/* 111 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 112 */ .multiply(step).add(start);
/* 113 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 114 */ partitionEnd = Long.MAX_VALUE;
/* 115 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 116 */ partitionEnd = Long.MIN_VALUE;
/* 117 */ } else {
/* 118 */ partitionEnd = end.longValue();
/* 119 */ }
/* 120 */
/* 121 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 122 */ java.math.BigInteger.valueOf(range_nextIndex_0));
/* 123 */ range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 124 */ if (range_numElementsTodo_0 < 0) {
/* 125 */ range_numElementsTodo_0 = 0;
/* 126 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 127 */ range_numElementsTodo_0++;
/* 128 */ }
/* 129 */ }
/* 130 */
/* 131 */ protected void processNext() throws java.io.IOException {
/* 132 */ // initialize Range
/* 133 */ if (!range_initRange_0) {
/* 134 */ range_initRange_0 = true;
/* 135 */ initRange(partitionIndex);
/* 136 */ }
/* 137 */
/* 138 */ while (true) {
/* 139 */ if (range_nextIndex_0 == range_batchEnd_0) {
/* 140 */ long range_nextBatchTodo_0;
/* 141 */ if (range_numElementsTodo_0 > 1000L) {
/* 142 */ range_nextBatchTodo_0 = 1000L;
/* 143 */ range_numElementsTodo_0 -= 1000L;
/* 144 */ } else {
/* 145 */ range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 146 */ range_numElementsTodo_0 = 0;
/* 147 */ if (range_nextBatchTodo_0 == 0) break;
/* 148 */ }
/* 149 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 150 */ }
/* 151 */
/* 152 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 153 */ for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 154 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 155 */
/* 156 */ // common sub-expressions
/* 157 */
/* 158 */ bnlj_doConsume_0(range_value_0);
/* 159 */
/* 160 */ if (shouldStop()) {
/* 161 */ range_nextIndex_0 = range_value_0 + 1L;
/* 162 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 163 */ range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 164 */ return;
/* 165 */ }
/* 166 */
/* 167 */ }
/* 168 */ range_nextIndex_0 = range_batchEnd_0;
/* 169 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 170 */ range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 171 */ range_taskContext_0.killTaskIfInterrupted();
/* 172 */ }
/* 173 */ }
/* 174 */
/* 175 */ }
```
Closes #31931 from linzebing/code-left-right-outer.
Authored-by: linzebing <linzebing1995@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>1 parent 39542bb commit e768eaa
File tree
4 files changed
+98
-6
lines changed- sql/core/src
- main/scala/org/apache/spark/sql/execution/joins
- test/scala/org/apache/spark/sql/execution
- joins
- metric
4 files changed
+98
-6
lines changedLines changed: 46 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
396 | 396 | | |
397 | 397 | | |
398 | 398 | | |
399 | | - | |
| 399 | + | |
| 400 | + | |
400 | 401 | | |
401 | 402 | | |
402 | 403 | | |
| |||
413 | 414 | | |
414 | 415 | | |
415 | 416 | | |
| 417 | + | |
416 | 418 | | |
417 | 419 | | |
418 | 420 | | |
| |||
458 | 460 | | |
459 | 461 | | |
460 | 462 | | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
| 502 | + | |
| 503 | + | |
| 504 | + | |
| 505 | + | |
461 | 506 | | |
462 | 507 | | |
463 | 508 | | |
| |||
Lines changed: 47 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
211 | 211 | | |
212 | 212 | | |
213 | 213 | | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
214 | 261 | | |
215 | 262 | | |
216 | 263 | | |
| |||
Lines changed: 2 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
149 | 149 | | |
150 | 150 | | |
151 | 151 | | |
152 | | - | |
| 152 | + | |
153 | 153 | | |
154 | 154 | | |
155 | 155 | | |
| |||
158 | 158 | | |
159 | 159 | | |
160 | 160 | | |
161 | | - | |
| 161 | + | |
162 | 162 | | |
163 | 163 | | |
164 | 164 | | |
| |||
Lines changed: 3 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
452 | 452 | | |
453 | 453 | | |
454 | 454 | | |
455 | | - | |
456 | | - | |
| 455 | + | |
| 456 | + | |
457 | 457 | | |
458 | 458 | | |
459 | | - | |
| 459 | + | |
460 | 460 | | |
461 | 461 | | |
462 | 462 | | |
| |||
0 commit comments