Skip to content

[SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer) #31931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

linzebing
Copy link
Contributor

What changes were proposed in this pull request?

This PR is to add code-gen support for left outer (build right) and right outer (build left). Reference: BroadcastNestedLoopJoinExec.codegenInner() and BroadcastNestedLoopJoinExec.outerJoin()

Why are the changes needed?

Improve query CPU performance.
Tested with a simple query:

val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left outer broadcast nested loop join", N) {
   val df = spark.range(N).selectExpr(s"id as k1").join(
     dim, col("k1") + 1 <= col("k2"), "left_outer")
   assert(df.queryExecution.sparkPlan.find(
     _.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined)
   df.noop()
}

Seeing 2x run time improvement:

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
left outer broadcast nested loop join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------------------
left outer broadcast nested loop join wholestage off           3024           3698         953          6.9         144.2       1.0X
left outer broadcast nested loop join wholestage on            1512           1659         172         13.9          72.1       2.0X

Does this PR introduce any user-facing change?

No

How was this patch tested?

Changed existing unit tests in OuterJoinSuite to cover codegen use cases.
Added unit test in WholeStageCodegenSuite.scala to make sure code-gen for broadcast nested loop join is taking effect, and test for multiple join case as well.

Example query:

val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_outer").explain("codegen")

Example generated code (bnlj_doConsume_0 method):

== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, LeftOuter, ((k1#2L + 1) <= k2#6L)
:- *(2) Project [id#0L AS k1#2L]
:  +- *(2) Range (0, 4, step=1, splits=16)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
   +- *(1) Project [id#4L AS k2#6L]
      +- *(1) Range (0, 3, step=1, splits=16)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_foundMatch_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */       boolean bnlj_shouldOutputRow_0 = false;
/* 041 */
/* 042 */       boolean bnlj_isNull_2 = true;
/* 043 */       long bnlj_value_2 = -1L;
/* 044 */       if (bnlj_buildRow_0 != null) {
/* 045 */         long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 046 */         bnlj_isNull_2 = false;
/* 047 */         bnlj_value_2 = bnlj_value_1;
/* 048 */       }
/* 049 */
/* 050 */       long bnlj_value_4 = -1L;
/* 051 */
/* 052 */       bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 053 */
/* 054 */       boolean bnlj_value_3 = false;
/* 055 */       bnlj_value_3 = bnlj_value_4 <= bnlj_value_2;
/* 056 */       if (!(false || !bnlj_value_3))
/* 057 */       {
/* 058 */         bnlj_shouldOutputRow_0 = true;
/* 059 */         bnlj_foundMatch_0 = true;
/* 060 */       }
/* 061 */       if (bnlj_arrayIndex_0 == bnlj_buildRowArray_0.length - 1 && !bnlj_foundMatch_0) {
/* 062 */         bnlj_buildRow_0 = null;
/* 063 */         bnlj_shouldOutputRow_0 = true;
/* 064 */       }
/* 065 */       if (bnlj_shouldOutputRow_0) {
/* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 067 */
/* 068 */         boolean bnlj_isNull_9 = true;
/* 069 */         long bnlj_value_9 = -1L;
/* 070 */         if (bnlj_buildRow_0 != null) {
/* 071 */           long bnlj_value_8 = bnlj_buildRow_0.getLong(0);
/* 072 */           bnlj_isNull_9 = false;
/* 073 */           bnlj_value_9 = bnlj_value_8;
/* 074 */         }
/* 075 */         range_mutableStateArray_0[3].reset();
/* 076 */
/* 077 */         range_mutableStateArray_0[3].zeroOutNullBytes();
/* 078 */
/* 079 */         range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 080 */
/* 081 */         if (bnlj_isNull_9) {
/* 082 */           range_mutableStateArray_0[3].setNullAt(1);
/* 083 */         } else {
/* 084 */           range_mutableStateArray_0[3].write(1, bnlj_value_9);
/* 085 */         }
/* 086 */         append((range_mutableStateArray_0[3].getRow()).copy());
/* 087 */
/* 088 */       }
/* 089 */     }
/* 090 */
/* 091 */   }
/* 092 */
/* 093 */   private void initRange(int idx) {
/* 094 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 095 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(16L);
/* 096 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 097 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 098 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 099 */     long partitionEnd;
/* 100 */
/* 101 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 102 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 103 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 104 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 105 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 106 */     } else {
/* 107 */       range_nextIndex_0 = st.longValue();
/* 108 */     }
/* 109 */     range_batchEnd_0 = range_nextIndex_0;
/* 110 */
/* 111 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 112 */     .multiply(step).add(start);
/* 113 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 114 */       partitionEnd = Long.MAX_VALUE;
/* 115 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 116 */       partitionEnd = Long.MIN_VALUE;
/* 117 */     } else {
/* 118 */       partitionEnd = end.longValue();
/* 119 */     }
/* 120 */
/* 121 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 122 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 123 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 124 */     if (range_numElementsTodo_0 < 0) {
/* 125 */       range_numElementsTodo_0 = 0;
/* 126 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 127 */       range_numElementsTodo_0++;
/* 128 */     }
/* 129 */   }
/* 130 */
/* 131 */   protected void processNext() throws java.io.IOException {
/* 132 */     // initialize Range
/* 133 */     if (!range_initRange_0) {
/* 134 */       range_initRange_0 = true;
/* 135 */       initRange(partitionIndex);
/* 136 */     }
/* 137 */
/* 138 */     while (true) {
/* 139 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 140 */         long range_nextBatchTodo_0;
/* 141 */         if (range_numElementsTodo_0 > 1000L) {
/* 142 */           range_nextBatchTodo_0 = 1000L;
/* 143 */           range_numElementsTodo_0 -= 1000L;
/* 144 */         } else {
/* 145 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 146 */           range_numElementsTodo_0 = 0;
/* 147 */           if (range_nextBatchTodo_0 == 0) break;
/* 148 */         }
/* 149 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 150 */       }
/* 151 */
/* 152 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 153 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 154 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 155 */
/* 156 */         // common sub-expressions
/* 157 */
/* 158 */         bnlj_doConsume_0(range_value_0);
/* 159 */
/* 160 */         if (shouldStop()) {
/* 161 */           range_nextIndex_0 = range_value_0 + 1L;
/* 162 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 163 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 164 */           return;
/* 165 */         }
/* 166 */
/* 167 */       }
/* 168 */       range_nextIndex_0 = range_batchEnd_0;
/* 169 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 170 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 171 */       range_taskContext_0.killTaskIfInterrupted();
/* 172 */     }
/* 173 */   }
/* 174 */
/* 175 */ }

@c21
Copy link
Contributor

c21 commented Mar 22, 2021

Thanks @linzebing for working on this. Can you add [SQL] to the PR title as well?

@c21
Copy link
Contributor

c21 commented Mar 22, 2021

I am looking this, also cc @cloud-fan and @maropu for review if you have time, thanks.

@c21
Copy link
Contributor

c21 commented Mar 22, 2021

Also the JIRA should be SPARK-34707, not SPARK-34706. https://issues.apache.org/jira/browse/SPARK-34706 is the umbrella task for all BNLJ improvement.

@linzebing linzebing changed the title [SPARK-34706] Code-gen broadcast nested loop join (left outer/right outer) [SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer) Mar 22, 2021
@github-actions github-actions bot added the SQL label Mar 22, 2021
@linzebing
Copy link
Contributor Author

One unit test in SQLMetricsSuite failed. Let me check.

@maropu
Copy link
Member

maropu commented Mar 22, 2021

ok to test

@maropu
Copy link
Member

maropu commented Mar 22, 2021

I'll check it later, too. Thanks for your work, @linzebing !

@linzebing linzebing force-pushed the code-left-right-outer branch from e7f119b to 08e47d2 Compare March 22, 2021 22:51
@linzebing
Copy link
Contributor Author

One unit test in SQLMetricsSuite failed. Let me check.

Fixed test.


s"""
|boolean $foundMatch = false;
|for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the broadcast side to be empty? It seems not right here because we still need to output one row for streamed side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right. Let me address this.

.join(df3, $"k1" <= $"k3", "left_outer")
hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(BroadcastNestedLoopJoinExec(
_: BroadcastNestedLoopJoinExec, _, _, _, _)) => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: seems indentation is off.


Seq(true, false).foreach { codegenEnabled =>
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
// test left outer join
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about adding an extra test case for broadcast side being empty?

@SparkQA
Copy link

SparkQA commented Mar 23, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40951/

@SparkQA
Copy link

SparkQA commented Mar 23, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40951/

@SparkQA
Copy link

SparkQA commented Mar 23, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40957/

@SparkQA
Copy link

SparkQA commented Mar 23, 2021

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40957/

@SparkQA
Copy link

SparkQA commented Mar 23, 2021

Test build #136367 has finished for PR 31931 at commit 08e47d2.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Mar 23, 2021

Test build #136374 has finished for PR 31931 at commit 8ee7536.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@cloud-fan
Copy link
Contributor

thanks, merging to master!

@cloud-fan cloud-fan closed this in e768eaa Mar 23, 2021
@maropu
Copy link
Member

maropu commented Mar 23, 2021

late lgtm.

Copy link
Contributor

@c21 c21 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Late LGTM too.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants