Commit b5b1985
[SPARK-34620][SQL] Code-gen broadcast nested loop join (inner/cross)
### What changes were proposed in this pull request?
`BroadcastNestedLoopJoinExec` does not support code-gen, and adding it can potentially boost CPU performance for this operator. The blog post https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html also shows evidence of this in a Spark fork.
The codegen for `BroadcastNestedLoopJoinExec` shares some code with `HashJoin`, so the interface `JoinCodegenSupport` is introduced to hold the common logic. This PR only supports inner and cross joins; other join types will be added in follow-up PRs.
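To illustrate the shape of the code this generates, here is a minimal, hypothetical Scala sketch (a simplification, not Spark internals) of the logic that the generated `bnlj_doConsume_0` method below corresponds to: for each streamed row, loop over the broadcast build-side row array, evaluate the join condition inline, and emit the rows that pass.
```
// Hypothetical simplification, not Spark code: the per-streamed-row loop
// that the generated bnlj_doConsume_0 method below implements.
def doConsume(streamedK1: Long, buildRows: Array[Long], emit: (Long, Long) => Unit): Unit = {
  var i = 0
  while (i < buildRows.length) {       // iterate over the broadcast build side
    val buildK2 = buildRows(i)
    if (streamedK1 + 1 != buildK2) {   // join condition evaluated inline
      emit(streamedK1, buildK2)        // output the joined (k1, k2) row
    }
    i += 1
  }
}
```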
Example query and generated code:
```
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 =!= $"k2").explain("codegen")
```
```
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, Inner, NOT ((k1#2L + 1) = k2#6L)
:- *(2) Project [id#0L AS k1#2L]
: +- *(2) Range (0, 4, step=1, splits=2)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
+- *(1) Project [id#4L AS k2#6L]
+- *(1) Range (0, 3, step=1, splits=2)
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean range_initRange_0;
/* 010 */ private long range_nextIndex_0;
/* 011 */ private TaskContext range_taskContext_0;
/* 012 */ private InputMetrics range_inputMetrics_0;
/* 013 */ private long range_batchEnd_0;
/* 014 */ private long range_numElementsTodo_0;
/* 015 */ private InternalRow[] bnlj_buildRowArray_0;
/* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */ public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */ this.references = references;
/* 020 */ }
/* 021 */
/* 022 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */ partitionIndex = index;
/* 024 */ this.inputs = inputs;
/* 025 */
/* 026 */ range_taskContext_0 = TaskContext.get();
/* 027 */ range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */ range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */ range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */ range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */ bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */ range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */ }
/* 035 */
/* 036 */ private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */ for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 038 */ UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 039 */
/* 040 */ long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 041 */
/* 042 */ long bnlj_value_4 = -1L;
/* 043 */
/* 044 */ bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 045 */
/* 046 */ boolean bnlj_value_3 = false;
/* 047 */ bnlj_value_3 = bnlj_value_4 == bnlj_value_1;
/* 048 */ boolean bnlj_value_2 = false;
/* 049 */ bnlj_value_2 = !(bnlj_value_3);
/* 050 */ if (!(false || !bnlj_value_2))
/* 051 */ {
/* 052 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 053 */
/* 054 */ range_mutableStateArray_0[3].reset();
/* 055 */
/* 056 */ range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 057 */
/* 058 */ range_mutableStateArray_0[3].write(1, bnlj_value_1);
/* 059 */ append((range_mutableStateArray_0[3].getRow()).copy());
/* 060 */
/* 061 */ }
/* 062 */ }
/* 063 */
/* 064 */ }
/* 065 */
/* 066 */ private void initRange(int idx) {
/* 067 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 068 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 069 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 070 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 071 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 072 */ long partitionEnd;
/* 073 */
/* 074 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 075 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 076 */ range_nextIndex_0 = Long.MAX_VALUE;
/* 077 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 078 */ range_nextIndex_0 = Long.MIN_VALUE;
/* 079 */ } else {
/* 080 */ range_nextIndex_0 = st.longValue();
/* 081 */ }
/* 082 */ range_batchEnd_0 = range_nextIndex_0;
/* 083 */
/* 084 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 085 */ .multiply(step).add(start);
/* 086 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 087 */ partitionEnd = Long.MAX_VALUE;
/* 088 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 089 */ partitionEnd = Long.MIN_VALUE;
/* 090 */ } else {
/* 091 */ partitionEnd = end.longValue();
/* 092 */ }
/* 093 */
/* 094 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 095 */ java.math.BigInteger.valueOf(range_nextIndex_0));
/* 096 */ range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 097 */ if (range_numElementsTodo_0 < 0) {
/* 098 */ range_numElementsTodo_0 = 0;
/* 099 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 100 */ range_numElementsTodo_0++;
/* 101 */ }
/* 102 */ }
/* 103 */
/* 104 */ protected void processNext() throws java.io.IOException {
/* 105 */ // initialize Range
/* 106 */ if (!range_initRange_0) {
/* 107 */ range_initRange_0 = true;
/* 108 */ initRange(partitionIndex);
/* 109 */ }
/* 110 */
/* 111 */ while (true) {
/* 112 */ if (range_nextIndex_0 == range_batchEnd_0) {
/* 113 */ long range_nextBatchTodo_0;
/* 114 */ if (range_numElementsTodo_0 > 1000L) {
/* 115 */ range_nextBatchTodo_0 = 1000L;
/* 116 */ range_numElementsTodo_0 -= 1000L;
/* 117 */ } else {
/* 118 */ range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 119 */ range_numElementsTodo_0 = 0;
/* 120 */ if (range_nextBatchTodo_0 == 0) break;
/* 121 */ }
/* 122 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 123 */ }
/* 124 */
/* 125 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 126 */ for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 127 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 128 */
/* 129 */ // common sub-expressions
/* 130 */
/* 131 */ bnlj_doConsume_0(range_value_0);
/* 132 */
/* 133 */ if (shouldStop()) {
/* 134 */ range_nextIndex_0 = range_value_0 + 1L;
/* 135 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 136 */ range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 137 */ return;
/* 138 */ }
/* 139 */
/* 140 */ }
/* 141 */ range_nextIndex_0 = range_batchEnd_0;
/* 142 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 143 */ range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 144 */ range_taskContext_0.killTaskIfInterrupted();
/* 145 */ }
/* 146 */ }
/* 147 */
/* 148 */ }
```
### Why are the changes needed?
Improve query CPU performance. Added a micro benchmark query in `JoinBenchmark.scala`.
Saw roughly a 2x run-time improvement with whole-stage codegen on:
```
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 2.50GHz
broadcast nested loop join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------
broadcast nested loop join wholestage off 62922 63052 184 0.3 3000.3 1.0X
broadcast nested loop join wholestage on 30946 30972 26 0.7 1475.6 2.0X
```
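For reference, a hedged sketch of what such a micro benchmark in `JoinBenchmark.scala` could look like (sizes, column names, and the join condition here are assumptions, not necessarily the exact code added by this PR), using the suite's `codegenBenchmark` helper, which runs the body with whole-stage codegen off and on:
```
// Sketch only: sizes, column names, and the join condition are assumptions.
import org.apache.spark.sql.functions.{broadcast, col}

val N = 20 << 20   // streamed side: ~20M rows
val M = 1 << 4     // broadcast build side: 16 rows

val dim = broadcast(spark.range(M).selectExpr("id AS k2"))
codegenBenchmark("broadcast nested loop join", N) {
  val df = spark.range(N).selectExpr("id AS k1")
    .join(dim, col("k1") >= col("k2"))
  df.noop()   // force execution without collecting results
}
```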
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
* Added a unit test in `WholeStageCodegenSuite.scala` (a sketch of such a test is shown after this list), plus the existing unit tests for `BroadcastNestedLoopJoinExec`.
* Updated golden files for several TPC-DS query plans, as whole-stage code-gen for `BroadcastNestedLoopJoinExec` is now triggered.
* Updated `JoinBenchmark-jdk11-results.txt` and `JoinBenchmark-results.txt` with the new benchmark results. Followed previous benchmark PRs #27078 and #26003 and used the same type of machine:
```
Amazon AWS EC2
type: r3.xlarge
region: us-west-2 (Oregon)
OS: Linux
```
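A hedged sketch of what the whole-stage codegen test could look like (test name and assertions are assumptions, not necessarily the exact test added by this PR): it checks that the broadcast nested loop join ends up inside a `WholeStageCodegenExec` subtree and that the query still returns correct results.
```
// Sketch only, assuming the ScalaTest style used in WholeStageCodegenSuite.
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec

test("broadcast nested loop join (inner) runs with whole-stage codegen") {
  val df1 = spark.range(4).select($"id".as("k1"))
  val df2 = spark.range(3).select($"id".as("k2"))
  val joined = df1.join(df2, $"k1" + 1 =!= $"k2")
  val plan = joined.queryExecution.executedPlan
  // The join should be wrapped inside a WholeStageCodegenExec subtree.
  val codegenJoins = plan.collect {
    case w: WholeStageCodegenExec
        if w.collect { case j: BroadcastNestedLoopJoinExec => j }.nonEmpty => w
  }
  assert(codegenJoins.nonEmpty)
  assert(joined.count() == 10)  // 4 * 3 pairs minus the 2 where k1 + 1 == k2
}
```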
Closes #31736 from c21/nested-join-exec.
Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
File tree (36 files changed, +1557 / -1378 lines changed):
- sql/core
  - benchmarks
  - src
    - main/scala/org/apache/spark/sql/execution
      - joins
    - test
      - resources/tpcds-plan-stability
        - approved-plans-v1_4
          - q28.sf100
          - q28
          - q61.sf100
          - q61
          - q77.sf100
          - q77
          - q88.sf100
          - q88
          - q90.sf100
          - q90
        - approved-plans-v2_7
          - q22.sf100
          - q22
          - q77a.sf100
          - q77a
      - scala/org/apache/spark/sql/execution
        - benchmark