Skip to content

[SPARK-34707][SQL] Code-gen broadcast nested loop join (left outer/right outer) #31931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ case class BroadcastNestedLoopJoinExec(
}

// Returns true only for the (joinType, buildSide) combinations matched below: inner-like joins
// with either build side, left outer with BuildRight, right outer with BuildLeft, and
// left semi/anti with BuildRight. Every other combination returns false.
// NOTE(review): every alternative of the first case is repeated in the second case, so the
// first case is redundant. It looks like the pre-change line of this diff hunk shown next to
// its replacement -- confirm before removing it.
override def supportCodegen: Boolean = (joinType, buildSide) match {
case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true
case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) |
(LeftSemi | LeftAnti, BuildRight) => true
case _ => false
}

Expand All @@ -413,6 +414,7 @@ case class BroadcastNestedLoopJoinExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
(joinType, buildSide) match {
case (_: InnerLike, _) => codegenInner(ctx, input)
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input)
case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true)
case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false)
case _ =>
Expand Down Expand Up @@ -458,6 +460,49 @@ case class BroadcastNestedLoopJoinExec(
""".stripMargin
}

/**
 * Builds the Java snippet that joins the current streamed row against the broadcast rows for
 * the outer-join cases routed here by `doConsume`.
 *
 * The generated code emits one result row for every build row that passes the join condition.
 * If no build row passes (or the broadcast array is empty), it emits exactly one result row
 * with the build row set to null.
 *
 * @param ctx codegen context used for fresh variable names and passed to the helper calls
 * @param input generated variables for the current streamed row
 * @return the generated Java source as a string
 */
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// NOTE(review): presumably prepareBroadcast returns the broadcast rows together with the name
// of the generated variable that holds them -- confirm against its definition.
val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
// buildRow is the name of the build-row variable declared in the generated code below;
// checkCondition is the generated condition prefix placed in front of the `{ ... }` match block.
val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast)
val buildVars = genBuildSideVars(ctx, buildRow, broadcast)

// Build-side columns go first for BuildLeft and last for BuildRight.
val resultVars = buildSide match {
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
val arrayIndex = ctx.freshName("arrayIndex")
val shouldOutputRow = ctx.freshName("shouldOutputRow")
val foundMatch = ctx.freshName("foundMatch")
val numOutput = metricTerm(ctx, "numOutputRows")

if (buildRowArray.isEmpty) {
// Empty build side: nothing can match, so emit the streamed row once with a null build row.
s"""
|UnsafeRow $buildRow = null;
|$numOutput.add(1);
|${consume(ctx, resultVars)}
""".stripMargin
} else {
// Non-empty build side: scan every build row and emit on each match. On the last iteration,
// if nothing has matched so far, null out the build row and emit one unmatched result row.
s"""
|boolean $foundMatch = false;
|for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
| UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
| boolean $shouldOutputRow = false;
| $checkCondition {
| $shouldOutputRow = true;
| $foundMatch = true;
| }
| if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) {
| $buildRow = null;
| $shouldOutputRow = true;
| }
| if ($shouldOutputRow) {
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
|}
""".stripMargin
}
}

private def codegenLeftExistence(
ctx: CodegenContext,
input: Seq[ExprCode],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,53 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
}

// Runs each join with whole-stage codegen enabled and disabled. In both modes the answer must
// match, and the BroadcastNestedLoopJoinExec must sit directly under a WholeStageCodegenExec
// exactly when codegen is enabled.
test("Left/Right outer BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
val df3 = spark.range(2).select($"id".as("k3"))
// Empty input, used as the right side of the "build side is empty" case below.
val df4 = spark.range(0).select($"id".as("k4"))

Seq(true, false).foreach { codegenEnabled =>
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
// test left outer join
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about adding an extra test case for broadcast side being empty?

val leftOuterJoinDF = df1.join(df2, $"k1" > $"k2", "left_outer")
// True when exactly one WholeStageCodegenExec directly wraps a BroadcastNestedLoopJoinExec.
var hasJoinInCodegen = leftOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
// k1 = 0 has no smaller k2, so it appears once with a null right side.
checkAnswer(leftOuterJoinDF,
Seq(Row(0, null), Row(1, 0), Row(2, 0), Row(2, 1), Row(3, 0), Row(3, 1), Row(3, 2)))

// test right outer join
val rightOuterJoinDF = df1.join(df2, $"k1" < $"k2", "right_outer")
hasJoinInCodegen = rightOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
// k2 = 0 has no smaller k1, so it appears once with a null left side.
checkAnswer(rightOuterJoinDF, Seq(Row(null, 0), Row(0, 1), Row(0, 2), Row(1, 2)))

// test a combination of left outer and right outer joins
val twoJoinsDF = df1.join(df2, $"k1" > $"k2" + 1, "right_outer")
.join(df3, $"k1" <= $"k3", "left_outer")
// Both joins must share one codegen stage: the join under WholeStageCodegenExec has another
// BroadcastNestedLoopJoinExec as its first child.
hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(BroadcastNestedLoopJoinExec(
_: BroadcastNestedLoopJoinExec, _, _, _, _)) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
checkAnswer(twoJoinsDF,
Seq(Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null)))

// test build side is empty
val buildSideIsEmptyDF = df3.join(df4, $"k3" > $"k4", "left_outer")
hasJoinInCodegen = buildSideIsEmptyDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
// With no rows in df4, every df3 row is returned once with a null right side.
checkAnswer(buildSideIsEmptyDF, Seq(Row(0, null), Row(1, null)))
}
}
}

test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
}
}

test(s"$testName using BroadcastNestedLoopJoin build left") {
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build left") { _ =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)),
Expand All @@ -158,7 +158,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
}
}

test(s"$testName using BroadcastNestedLoopJoin build right") {
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a"
val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " +
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a"
Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true))
.foreach { case (query, enableWholeStage) =>
Seq((leftQuery, 0L, false), (rightQuery, 0L, false), (leftQuery, 1L, true),
(rightQuery, 1L, true)).foreach { case (query, nodeId, enableWholeStage) =>
val df = spark.sql(query)
testSparkPlanMetrics(df, 2, Map(
0L -> (("BroadcastNestedLoopJoin", Map(
nodeId -> (("BroadcastNestedLoopJoin", Map(
"number of output rows" -> 12L)))),
enableWholeStage
)
Expand Down