Commit 5e62e3a

[SPARK-41162][SQL] Fix anti- and semi-join for self-join with aggregations
Rule `PushDownLeftSemiAntiJoin` should not push an anti-join below an `Aggregate` when the join condition references an attribute that exists in its right plan and its left plan's child. This usually happens when the anti-join / semi-join is a self-join and `DeduplicateRelations` cannot deduplicate those attributes (in this example due to the projection of `value` to `id`).

This safety guard already exists for `Project` and `Union`, but `Aggregate` lacks it. Without this change, the optimizer creates an incorrect plan.

This example fails with `distinct()` (an aggregation) and succeeds without `distinct()`, even though both queries are identical:

```scala
val ids = Seq(1, 2, 3).toDF("id").distinct()
val result = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
assert(result.length == 1)
```

With `distinct()`, rule `PushDownLeftSemiAntiJoin` creates a join condition `(value#907 + 1) = value#907`, which can never be true. This effectively removes the anti-join.

**Before this PR:**

The anti-join is fully removed from the plan:

```
== Physical Plan ==
AdaptiveSparkPlan (16)
+- == Final Plan ==
   LocalTableScan (1)

(16) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

This is caused by `PushDownLeftSemiAntiJoin` adding join condition `(value#907 + 1) = value#907`, which is wrong because `id#910` in `(id#910 + 1) AS id#912` exists in the right child of the join as well as in the left grandchild:

```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin ===
!Join LeftAnti, (id#912 = id#910)                   Aggregate [id#910], [(id#910 + 1) AS id#912]
!:- Aggregate [id#910], [(id#910 + 1) AS id#912]    +- Project [value#907 AS id#910]
!:  +- Project [value#907 AS id#910]                   +- Join LeftAnti, ((value#907 + 1) = value#907)
!:     +- LocalRelation [value#907]                       :- LocalRelation [value#907]
!+- Aggregate [id#910], [id#910]                          +- Aggregate [id#910], [id#910]
!   +- Project [value#914 AS id#910]                         +- Project [value#914 AS id#910]
!      +- LocalRelation [value#914]                             +- LocalRelation [value#914]
```

The right child of the join and the left grandchild would become the children of the pushed-down join, which creates an invalid join condition.

**After this PR:**

The join condition is recognized as becoming ambiguous: `id#910`, referenced via `(id#910 + 1) AS id#912`, exists on both sides of the prospective pushed-down join. Hence, the join is not pushed down, and the rule no longer applies. The final plan contains the anti-join:

```
== Physical Plan ==
AdaptiveSparkPlan (24)
+- == Final Plan ==
   * BroadcastHashJoin LeftSemi BuildRight (14)
   :- * HashAggregate (7)
   :  +- AQEShuffleRead (6)
   :     +- ShuffleQueryStage (5), Statistics(sizeInBytes=48.0 B, rowCount=3)
   :        +- Exchange (4)
   :           +- * HashAggregate (3)
   :              +- * Project (2)
   :                 +- * LocalTableScan (1)
   +- BroadcastQueryStage (13), Statistics(sizeInBytes=1024.0 KiB, rowCount=3)
      +- BroadcastExchange (12)
         +- * HashAggregate (11)
            +- AQEShuffleRead (10)
               +- ShuffleQueryStage (9), Statistics(sizeInBytes=48.0 B, rowCount=3)
                  +- ReusedExchange (8)

(8) ReusedExchange [Reuses operator id: 4]
Output [1]: [id#898]

(24) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

This fixes a correctness issue.

Tested via unit tests in `DataFrameJoinSuite` and `LeftSemiAntiJoinPushDownSuite`.

Closes apache#39131 from EnricoMi/branch-antijoin-selfjoin-fix.

Authored-by: Enrico Minack <github@enrico.minack.dev>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 61e0348 commit 5e62e3a
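
The semi-join variant of the reproduction above behaves the same way; a sketch, assuming `spark.implicits._` is in scope, with the expected count of 2 taken from the new `DataFrameJoinSuite` test below:

```scala
import spark.implicits._  // assumes an active SparkSession named `spark`

val ids = Seq(1, 2, 3).toDF("id").distinct()
// ids + 1 = {2, 3, 4}; a semi-join against {1, 2, 3} must keep {2, 3}
val semi = ids.withColumn("id", $"id" + 1)
  .join(ids, Seq("id"), "left_semi")
  .collect()
assert(semi.length == 2)
```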

File tree: 3 files changed, +63 −25 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala

Lines changed: 8 additions & 7 deletions
```diff
@@ -52,9 +52,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
     }
 
     // LeftSemi/LeftAnti over Aggregate
-    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
         if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
-          !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
+          !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+          canPushThroughCondition(agg.children, joinCond, rightOp) =>
       val aliasMap = getAliasMap(agg)
       val canPushDownPredicate = (predicate: Expression) => {
         val replaced = replaceAlias(predicate, aliasMap)
@@ -100,11 +101,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
   }
 
   /**
-   * Check if we can safely push a join through a project or union by making sure that attributes
-   * referred in join condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when both sides of join refers to the same source (self join). This
-   * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
-   * present in both the legs of the join) or else the resultant plan will be invalid.
+   * Check if we can safely push a join through a project, aggregate, or union by making sure that
+   * attributes referred in join condition do not contain the same attributes as the plan they are
+   * moved into. This can happen when both sides of join refers to the same source (self join).
+   * This function makes sure that the join condition refers to attributes that are not ambiguous
+   * (i.e present in both the legs of the join) or else the resultant plan will be invalid.
    */
   private def canPushThroughCondition(
       plans: Seq[LogicalPlan],
```
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala

Lines changed: 37 additions & 18 deletions
```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types.IntegerType
 
-class LeftSemiPushdownSuite extends PlanTest {
+class LeftSemiAntiJoinPushDownSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -45,7 +45,7 @@ class LeftSemiPushdownSuite extends PlanTest {
   val testRelation1 = LocalRelation('d.int)
   val testRelation2 = LocalRelation('e.int)
 
-  test("Project: LeftSemiAnti join pushdown") {
+  test("Project: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -58,7 +58,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+  test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
     val originalQuery = testRelation
       .select(Rand('a), 'b, 'c)
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -67,7 +67,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Project: LeftSemiAnti join non correlated scalar subq") {
+  test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .select(subq.as("sum"))
@@ -82,7 +82,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+  test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
     val testRelation2 = LocalRelation('e.int, 'f.int)
     val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
     val subqExpr = ScalarSubquery(subqPlan)
@@ -94,7 +94,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Aggregate: LeftSemiAnti join pushdown") {
+  test("Aggregate: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -108,7 +108,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+  test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
     val originalQuery = testRelation
       .groupBy('b)('b, Rand(10).as('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -141,7 +141,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("LeftSemiAnti join over aggregate - no pushdown") {
+  test("Aggregate: LeftSemi join no pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c).as('sum))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
@@ -150,7 +150,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+  test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .groupBy('a) ('a, subq.as("sum"))
@@ -165,7 +165,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("LeftSemiAnti join over Window") {
+  test("Window: LeftSemi join pushdown") {
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
 
     val originalQuery = testRelation
@@ -183,7 +183,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Window: LeftSemi partial pushdown") {
+  test("Window: LeftSemi join partial pushdown") {
     // Attributes from join condition which does not refer to the window partition spec
     // are kept up in the plan as a Filter operator above Window.
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
@@ -223,7 +223,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Union: LeftSemiAnti join pushdown") {
+  test("Union: LeftSemi join pushdown") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
 
     val originalQuery = Union(Seq(testRelation, testRelation2))
@@ -249,7 +249,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemiAnti join pushdown") {
+  test("Unary: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -264,7 +264,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+  test("Unary: LeftSemi join pushdown - empty join condition") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -279,7 +279,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemi join pushdown - partial pushdown") {
+  test("Unary: LeftSemi join partial pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -295,7 +295,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftAnti join pushdown - no pushdown") {
+  test("Unary: LeftAnti join no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -305,7 +305,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+  test("Unary: LeftSemi join - no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -315,7 +315,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemi join push down through Expand") {
+  test("Unary: LeftSemi join pushdown through Expand") {
     val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
       Seq('a, 'b, 'c), testRelation)
     val originalQuery = expand
@@ -421,6 +421,25 @@ class LeftSemiPushdownSuite extends PlanTest {
     }
   }
 
+  Seq(LeftSemi, LeftAnti).foreach { case jt =>
+    test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
+      val aggregation = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)('id, sum('c).as("sum"))
+
+      // reference "b" exists in left leg, and the children of the right leg of the join
+      val originalQuery = aggregation.select(('id + 1).as("id_plus_1"), 'sum)
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)(('id + 1).as("id_plus_1"), sum('c).as("sum"))
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
     Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
       test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 18 additions & 0 deletions
```diff
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
     }
   }
 
+  Seq("left_semi", "left_anti").foreach { joinType =>
+    test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+      // aggregated dataframe
+      val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+      // self-joined via joinType
+      val result = ids.withColumn("id", $"id" + 1)
+        .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+      val expected = joinType match {
+        case "left_semi" => 2
+        case "left_anti" => 1
+        case _ => -1 // unsupported test type, test will always fail
+      }
+      assert(result.length == expected)
+    }
+  }
+
   def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
     case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
     case Filter(_, child) => extractLeftDeepInnerJoins(child)
```
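
A quick way to verify the fix interactively — a sketch assuming a spark-shell session with `spark.implicits._` in scope; with the fix, the optimized logical plan should still contain a `Join LeftAnti` node:

```scala
val ids = Seq(1, 2, 3).toDF("id").distinct()
val df = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti")

df.explain(true)  // the optimized plan should retain the LeftAnti join
// ids + 1 = {2, 3, 4}; an anti-join against {1, 2, 3} leaves only 4
assert(df.collect().map(_.getInt(0)).toSeq == Seq(4))
```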
