
Commit eba2076

Eric5553 authored and cloud-fan committed
[SPARK-30842][SQL] Adjust abstraction structure for join operators
### What changes were proposed in this pull request?

Currently the join operators are not well abstracted, even though they share a lot of common logic. A trait can be created for easier pattern matching and other future convenience. This is a follow-up PR based on a comment in #27509.

This PR refines the following aspects:

1. Refine the structure of all physical join operators.
2. Add the missing `joinType` field to the `CartesianProductExec` operator.
3. Refine the code related to EXPLAIN FORMATTED.

The EXPLAIN FORMATTED changes are:

1. Converge all join operator `verboseStringWithOperatorId` implementations to `BaseJoinExec`. The join condition is displayed, and the join keys are displayed if they are not empty.
2. `#1` adds the join condition to `BroadcastNestedLoopJoinExec`.
3. `#1` does **NOT** affect `CartesianProductExec`, `SortMergeJoin`, or the `HashJoin`s, since they already had their own override implementations.
4. Converge all join operator `simpleStringWithNodeId` implementations to `BaseJoinExec`, which enhances the one-line description of `CartesianProductExec` by adding its `JoinType`.
5. Override `simpleStringWithNodeId` in `BroadcastNestedLoopJoinExec` to show the `BuildSide`, which was previously done only for the `HashJoin`s.

### Why are the changes needed?

To make the code consistent with other operators and to ease future work on join operators.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Existing tests

Closes #27595 from Eric5553/RefineJoin.

Authored-by: Eric Wu <492960551@qq.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
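For illustration only (a minimal sketch: the table names, attribute IDs, and operator IDs below are hypothetical, and the exact plan depends on data and configuration), a non-equi join is typically planned as a `BroadcastNestedLoopJoinExec`, so it exercises changes `#2` and `#5` above:

```scala
// Hypothetical repro in spark-shell: a non-equi join, which Spark typically
// plans as BroadcastNestedLoopJoinExec when one side can be broadcast.
spark.range(10).createOrReplaceTempView("t1")
spark.range(10).createOrReplaceTempView("t2")
spark.sql("EXPLAIN FORMATTED SELECT * FROM t1 JOIN t2 ON t1.id < t2.id")
  .show(truncate = false)

// Previously the node's one-line entry showed neither join type nor build
// side, and its detail section omitted the join condition:
//   (3) BroadcastNestedLoopJoin
// With this change, roughly (IDs illustrative):
//   (3) BroadcastNestedLoopJoin Inner BuildRight
//   ...
//   Join condition: (id#0L < id#2L)
```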
1 parent 14bb639 commit eba2076

7 files changed: +81 −60 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BaseJoinExec.scala

Lines changed: 56 additions & 0 deletions

```diff
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils}
+
+/**
+ * Holds common logic for join operators
+ */
+trait BaseJoinExec extends BinaryExecNode {
+  def joinType: JoinType
+  def condition: Option[Expression]
+  def leftKeys: Seq[Expression]
+  def rightKeys: Seq[Expression]
+
+  override def simpleStringWithNodeId(): String = {
+    val opId = ExplainUtils.getOpId(this)
+    s"$nodeName $joinType ($opId)".trim
+  }
+
+  override def verboseStringWithOperatorId(): String = {
+    val joinCondStr = if (condition.isDefined) {
+      s"${condition.get}"
+    } else "None"
+    if (leftKeys.nonEmpty || rightKeys.nonEmpty) {
+      s"""
+         |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
+         |${ExplainUtils.generateFieldString("Left keys", leftKeys)}
+         |${ExplainUtils.generateFieldString("Right keys", rightKeys)}
+         |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
+       """.stripMargin
+    } else {
+      s"""
+         |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
+         |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
+       """.stripMargin
+    }
+  }
+}
```
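For context, here is a minimal sketch of what a concrete operator now has to supply to inherit the shared explain strings; `ExampleJoinExec` is hypothetical and not part of this commit (the real adopters follow in the diffs below):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.BaseJoinExec

// Hypothetical operator, for illustration only. Extending BaseJoinExec pulls
// in simpleStringWithNodeId and verboseStringWithOperatorId for free.
case class ExampleJoinExec(
    joinType: JoinType,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan) extends BaseJoinExec {

  // No equi-join keys, so the trait's verboseStringWithOperatorId takes the
  // branch that prints only the join condition.
  override def leftKeys: Seq[Expression] = Nil
  override def rightKeys: Seq[Expression] = Nil

  override def output: Seq[Attribute] = left.output ++ right.output

  protected override def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException("illustrative sketch only")
}
```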

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{BooleanType, LongType}

@@ -44,7 +44,7 @@ case class BroadcastHashJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryExecNode with HashJoin with CodegenSupport {
+  extends HashJoin with CodegenSupport {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala

Lines changed: 10 additions & 2 deletions
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}

@@ -32,7 +32,10 @@ case class BroadcastNestedLoopJoinExec(
     right: SparkPlan,
     buildSide: BuildSide,
     joinType: JoinType,
-    condition: Option[Expression]) extends BinaryExecNode {
+    condition: Option[Expression]) extends BaseJoinExec {
+
+  override def leftKeys: Seq[Expression] = Nil
+  override def rightKeys: Seq[Expression] = Nil

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -43,6 +46,11 @@ case class BroadcastNestedLoopJoinExec(
     case BuildLeft => (right, left)
   }

+  override def simpleStringWithNodeId(): String = {
+    val opId = ExplainUtils.getOpId(this)
+    s"$nodeName $joinType ${buildSide} ($opId)".trim
+  }
+
   override def requiredChildDistribution: Seq[Distribution] = buildSide match {
     case BuildLeft =>
       BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
```
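A nested loop join has no equi-join keys, so `leftKeys` and `rightKeys` are defined as `Nil`; this steers the inherited `verboseStringWithOperatorId` into its condition-only branch, while the local `simpleStringWithNodeId` override adds the `BuildSide` that only the `HashJoin`s displayed before.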

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala

Lines changed: 8 additions & 13 deletions
```diff
@@ -22,7 +22,8 @@ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.CompletionIterator

@@ -60,23 +61,17 @@
 case class CartesianProductExec(
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryExecNode {
+    condition: Option[Expression]) extends BaseJoinExec {
+
+  override def joinType: JoinType = Inner
+  override def leftKeys: Seq[Expression] = Nil
+  override def rightKeys: Seq[Expression] = Nil
+
   override def output: Seq[Attribute] = left.output ++ right.output

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

-  override def verboseStringWithOperatorId(): String = {
-    val joinCondStr = if (condition.isDefined) {
-      s"${condition.get}"
-    } else "None"
-
-    s"""
-       |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
-       |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
-     """.stripMargin
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
```
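Fixing `joinType` to `Inner` reflects how the planner uses this operator (it produces a `CartesianProductExec` for inner joins without usable join keys); with the keys `Nil`, the local `verboseStringWithOperatorId` override becomes redundant with the trait's condition-only branch and is dropped, and the inherited `simpleStringWithNodeId` now adds `Inner` to the one-line description.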

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 2 additions & 23 deletions
```diff
@@ -22,39 +22,18 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{ExplainUtils, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{ExplainUtils, RowIterator}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{IntegralType, LongType}

-trait HashJoin {
-  self: SparkPlan =>
-
-  def leftKeys: Seq[Expression]
-  def rightKeys: Seq[Expression]
-  def joinType: JoinType
+trait HashJoin extends BaseJoinExec {
   def buildSide: BuildSide
-  def condition: Option[Expression]
-  def left: SparkPlan
-  def right: SparkPlan

   override def simpleStringWithNodeId(): String = {
     val opId = ExplainUtils.getOpId(this)
     s"$nodeName $joinType ${buildSide} ($opId)".trim
   }

-  override def verboseStringWithOperatorId(): String = {
-    val joinCondStr = if (condition.isDefined) {
-      s"${condition.get}"
-    } else "None"
-
-    s"""
-       |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
-       |${ExplainUtils.generateFieldString("Left keys", leftKeys)}
-       |${ExplainUtils.generateFieldString("Right keys", rightKeys)}
-       |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
-     """.stripMargin
-  }
-
   override def output: Seq[Attribute] = {
     joinType match {
       case _: InnerLike =>
```
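Note the design shift here: the `self: SparkPlan =>` self-type and the duplicated abstract members (`leftKeys`, `rightKeys`, `joinType`, `condition`, `left`, `right`) disappear because `HashJoin` now inherits them by extending `BaseJoinExec`; only `buildSide` remains hash-join-specific, and `simpleStringWithNodeId` stays overridden so that it can include it.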

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
@@ -39,7 +39,7 @@ case class ShuffledHashJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryExecNode with HashJoin {
+  extends HashJoin {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 1 addition & 18 deletions
```diff
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean = false) extends BinaryExecNode with CodegenSupport {
+    isSkewJoin: Boolean = false) extends BaseJoinExec with CodegenSupport {

   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -52,23 +52,6 @@ case class SortMergeJoinExec(

   override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator

-  override def simpleStringWithNodeId(): String = {
-    val opId = ExplainUtils.getOpId(this)
-    s"$nodeName $joinType ($opId)".trim
-  }
-
-  override def verboseStringWithOperatorId(): String = {
-    val joinCondStr = if (condition.isDefined) {
-      s"${condition.get}"
-    } else "None"
-    s"""
-       |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
-       |${ExplainUtils.generateFieldString("Left keys", leftKeys)}
-       |${ExplainUtils.generateFieldString("Right keys", rightKeys)}
-       |${ExplainUtils.generateFieldString("Join condition", joinCondStr)}
-     """.stripMargin
-  }
-
   override def output: Seq[Attribute] = {
     joinType match {
       case _: InnerLike =>
```
