Skip to content

Commit 1d7bcc8

Browse files
lianchengmarmbrus
authored andcommitted
[SQL] Fixes caching related JoinSuite failure
PR #2860 refines in-memory table statistics and enables broader broadcasted hash join optimization for in-memory tables. This makes `JoinSuite` fail when some test suite caches test table `testData` and gets executed before `JoinSuite`. Because expected `ShuffledHashJoin`s are optimized to `BroadcastedHashJoin` according to collected in-memory table statistics. This PR fixes this issue by clearing the cache before testing join operator selection. A separate test case is also added to test broadcasted hash join operator selection. Author: Cheng Lian <lian@databricks.com> Closes #2960 from liancheng/fix-join-suite and squashes the following commits: 715b2de [Cheng Lian] Fixes caching related JoinSuite failure
1 parent dea302d commit 1d7bcc8

File tree

2 files changed

+64
-57
lines changed

2 files changed

+64
-57
lines changed

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,13 @@ package org.apache.spark.sql
1919

2020
import org.scalatest.BeforeAndAfterEach
2121

22-
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
2322
import org.apache.spark.sql.TestData._
24-
import org.apache.spark.sql.catalyst.plans.JoinType
25-
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
26-
import org.apache.spark.sql.execution._
23+
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
24+
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
2725
import org.apache.spark.sql.execution.joins._
28-
import org.apache.spark.sql.test.TestSQLContext
2926
import org.apache.spark.sql.test.TestSQLContext._
3027

3128
class JoinSuite extends QueryTest with BeforeAndAfterEach {
32-
3329
// Ensures tables are loaded.
3430
TestData
3531

@@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
4137
assert(planned.size === 1)
4238
}
4339

44-
test("join operator selection") {
45-
def assertJoin(sqlString: String, c: Class[_]): Any = {
46-
val rdd = sql(sqlString)
47-
val physical = rdd.queryExecution.sparkPlan
48-
val operators = physical.collect {
49-
case j: ShuffledHashJoin => j
50-
case j: HashOuterJoin => j
51-
case j: LeftSemiJoinHash => j
52-
case j: BroadcastHashJoin => j
53-
case j: LeftSemiJoinBNL => j
54-
case j: CartesianProduct => j
55-
case j: BroadcastNestedLoopJoin => j
56-
}
57-
58-
assert(operators.size === 1)
59-
if (operators(0).getClass() != c) {
60-
fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
61-
}
40+
def assertJoin(sqlString: String, c: Class[_]): Any = {
41+
val rdd = sql(sqlString)
42+
val physical = rdd.queryExecution.sparkPlan
43+
val operators = physical.collect {
44+
case j: ShuffledHashJoin => j
45+
case j: HashOuterJoin => j
46+
case j: LeftSemiJoinHash => j
47+
case j: BroadcastHashJoin => j
48+
case j: LeftSemiJoinBNL => j
49+
case j: CartesianProduct => j
50+
case j: BroadcastNestedLoopJoin => j
51+
}
52+
53+
assert(operators.size === 1)
54+
if (operators(0).getClass() != c) {
55+
fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
6256
}
57+
}
6358

64-
val cases1 = Seq(
65-
("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
66-
("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
67-
("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
68-
("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
69-
("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
70-
("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
71-
("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
72-
("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
73-
("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
74-
("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
75-
("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
76-
("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
77-
("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
78-
("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
79-
("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
80-
("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
81-
("SELECT * FROM testData right join testData2 ON key = a where key=2",
59+
test("join operator selection") {
60+
clearCache()
61+
62+
Seq(
63+
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
64+
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
65+
("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
66+
("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
67+
("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
68+
("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
69+
("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
70+
("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
71+
("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
72+
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
73+
("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
74+
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
75+
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
76+
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
77+
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
78+
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
79+
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
8280
classOf[HashOuterJoin]),
83-
("SELECT * FROM testData right join testData2 ON key = a and key=2",
81+
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
8482
classOf[HashOuterJoin]),
85-
("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]),
86-
("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
87-
("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
88-
("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin])
89-
// TODO add BroadcastNestedLoopJoin
90-
)
91-
cases1.foreach { c => assertJoin(c._1, c._2) }
83+
("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
84+
// TODO add BroadcastNestedLoopJoin
85+
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
86+
}
87+
88+
test("broadcasted hash join operator selection") {
89+
clearCache()
90+
sql("CACHE TABLE testData")
91+
92+
Seq(
93+
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
94+
("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
95+
("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
96+
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
97+
98+
sql("UNCACHE TABLE testData")
9299
}
93100

94101
test("multiple-key equi-join is hash-join") {
@@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
171178
(4, "D", 4, "d") ::
172179
(5, "E", null, null) ::
173180
(6, "F", null, null) :: Nil)
174-
181+
175182
checkAnswer(
176183
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
177184
(1, "A", null, null) ::
@@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
180187
(4, "D", 4, "d") ::
181188
(5, "E", null, null) ::
182189
(6, "F", null, null) :: Nil)
183-
190+
184191
checkAnswer(
185192
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
186193
(1, "A", null, null) ::
@@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
189196
(4, "D", 4, "d") ::
190197
(5, "E", null, null) ::
191198
(6, "F", null, null) :: Nil)
192-
199+
193200
checkAnswer(
194201
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
195202
(1, "A", 1, "a") ::
@@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
300307
(4, "D", 4, "D") ::
301308
(null, null, 5, "E") ::
302309
(null, null, 6, "F") :: Nil)
303-
310+
304311
checkAnswer(
305312
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
306313
(1, "A", null, null) ::
@@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
310317
(4, "D", 4, "D") ::
311318
(null, null, 5, "E") ::
312319
(null, null, 6, "F") :: Nil)
313-
320+
314321
checkAnswer(
315322
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
316323
(1, "A", null, null) ::

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ object TestData {
8080
UpperCaseData(3, "C") ::
8181
UpperCaseData(4, "D") ::
8282
UpperCaseData(5, "E") ::
83-
UpperCaseData(6, "F") :: Nil)
83+
UpperCaseData(6, "F") :: Nil).toSchemaRDD
8484
upperCaseData.registerTempTable("upperCaseData")
8585

8686
case class LowerCaseData(n: Int, l: String)
@@ -89,7 +89,7 @@ object TestData {
8989
LowerCaseData(1, "a") ::
9090
LowerCaseData(2, "b") ::
9191
LowerCaseData(3, "c") ::
92-
LowerCaseData(4, "d") :: Nil)
92+
LowerCaseData(4, "d") :: Nil).toSchemaRDD
9393
lowerCaseData.registerTempTable("lowerCaseData")
9494

9595
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])

0 commit comments

Comments
 (0)