@@ -19,17 +19,13 @@ package org.apache.spark.sql
19
19
20
20
import org .scalatest .BeforeAndAfterEach
21
21
22
- import org .apache .spark .sql .catalyst .analysis .UnresolvedRelation
23
22
import org .apache .spark .sql .TestData ._
24
- import org .apache .spark .sql .catalyst .plans .JoinType
25
- import org .apache .spark .sql .catalyst .plans .{LeftOuter , RightOuter , FullOuter , Inner , LeftSemi }
26
- import org .apache .spark .sql .execution ._
23
+ import org .apache .spark .sql .catalyst .analysis .UnresolvedRelation
24
+ import org .apache .spark .sql .catalyst .plans .{FullOuter , Inner , LeftOuter , RightOuter }
27
25
import org .apache .spark .sql .execution .joins ._
28
- import org .apache .spark .sql .test .TestSQLContext
29
26
import org .apache .spark .sql .test .TestSQLContext ._
30
27
31
28
class JoinSuite extends QueryTest with BeforeAndAfterEach {
32
-
33
29
// Ensures tables are loaded.
34
30
TestData
35
31
@@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
41
37
assert(planned.size === 1 )
42
38
}
43
39
44
- test(" join operator selection" ) {
45
- def assertJoin (sqlString : String , c : Class [_]): Any = {
46
- val rdd = sql(sqlString)
47
- val physical = rdd.queryExecution.sparkPlan
48
- val operators = physical.collect {
49
- case j : ShuffledHashJoin => j
50
- case j : HashOuterJoin => j
51
- case j : LeftSemiJoinHash => j
52
- case j : BroadcastHashJoin => j
53
- case j : LeftSemiJoinBNL => j
54
- case j : CartesianProduct => j
55
- case j : BroadcastNestedLoopJoin => j
56
- }
57
-
58
- assert(operators.size === 1 )
59
- if (operators(0 ).getClass() != c) {
60
- fail(s " $sqlString expected operator: $c, but got ${operators(0 )}\n physical: \n $physical" )
61
- }
40
+ def assertJoin (sqlString : String , c : Class [_]): Any = {
41
+ val rdd = sql(sqlString)
42
+ val physical = rdd.queryExecution.sparkPlan
43
+ val operators = physical.collect {
44
+ case j : ShuffledHashJoin => j
45
+ case j : HashOuterJoin => j
46
+ case j : LeftSemiJoinHash => j
47
+ case j : BroadcastHashJoin => j
48
+ case j : LeftSemiJoinBNL => j
49
+ case j : CartesianProduct => j
50
+ case j : BroadcastNestedLoopJoin => j
51
+ }
52
+
53
+ assert(operators.size === 1 )
54
+ if (operators(0 ).getClass() != c) {
55
+ fail(s " $sqlString expected operator: $c, but got ${operators(0 )}\n physical: \n $physical" )
62
56
}
57
+ }
63
58
64
- val cases1 = Seq (
65
- (" SELECT * FROM testData left semi join testData2 ON key = a" , classOf [LeftSemiJoinHash ]),
66
- (" SELECT * FROM testData left semi join testData2" , classOf [LeftSemiJoinBNL ]),
67
- (" SELECT * FROM testData join testData2" , classOf [CartesianProduct ]),
68
- (" SELECT * FROM testData join testData2 where key=2" , classOf [CartesianProduct ]),
69
- (" SELECT * FROM testData left join testData2" , classOf [CartesianProduct ]),
70
- (" SELECT * FROM testData right join testData2" , classOf [CartesianProduct ]),
71
- (" SELECT * FROM testData full outer join testData2" , classOf [CartesianProduct ]),
72
- (" SELECT * FROM testData left join testData2 where key=2" , classOf [CartesianProduct ]),
73
- (" SELECT * FROM testData right join testData2 where key=2" , classOf [CartesianProduct ]),
74
- (" SELECT * FROM testData full outer join testData2 where key=2" , classOf [CartesianProduct ]),
75
- (" SELECT * FROM testData join testData2 where key>a" , classOf [CartesianProduct ]),
76
- (" SELECT * FROM testData full outer join testData2 where key>a" , classOf [CartesianProduct ]),
77
- (" SELECT * FROM testData join testData2 ON key = a" , classOf [ShuffledHashJoin ]),
78
- (" SELECT * FROM testData join testData2 ON key = a and key=2" , classOf [ShuffledHashJoin ]),
79
- (" SELECT * FROM testData join testData2 ON key = a where key=2" , classOf [ShuffledHashJoin ]),
80
- (" SELECT * FROM testData left join testData2 ON key = a" , classOf [HashOuterJoin ]),
81
- (" SELECT * FROM testData right join testData2 ON key = a where key=2" ,
59
+ test(" join operator selection" ) {
60
+ clearCache()
61
+
62
+ Seq (
63
+ (" SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a" , classOf [LeftSemiJoinHash ]),
64
+ (" SELECT * FROM testData LEFT SEMI JOIN testData2" , classOf [LeftSemiJoinBNL ]),
65
+ (" SELECT * FROM testData JOIN testData2" , classOf [CartesianProduct ]),
66
+ (" SELECT * FROM testData JOIN testData2 WHERE key = 2" , classOf [CartesianProduct ]),
67
+ (" SELECT * FROM testData LEFT JOIN testData2" , classOf [CartesianProduct ]),
68
+ (" SELECT * FROM testData RIGHT JOIN testData2" , classOf [CartesianProduct ]),
69
+ (" SELECT * FROM testData FULL OUTER JOIN testData2" , classOf [CartesianProduct ]),
70
+ (" SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2" , classOf [CartesianProduct ]),
71
+ (" SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2" , classOf [CartesianProduct ]),
72
+ (" SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2" , classOf [CartesianProduct ]),
73
+ (" SELECT * FROM testData JOIN testData2 WHERE key > a" , classOf [CartesianProduct ]),
74
+ (" SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a" , classOf [CartesianProduct ]),
75
+ (" SELECT * FROM testData JOIN testData2 ON key = a" , classOf [ShuffledHashJoin ]),
76
+ (" SELECT * FROM testData JOIN testData2 ON key = a and key = 2" , classOf [ShuffledHashJoin ]),
77
+ (" SELECT * FROM testData JOIN testData2 ON key = a where key = 2" , classOf [ShuffledHashJoin ]),
78
+ (" SELECT * FROM testData LEFT JOIN testData2 ON key = a" , classOf [HashOuterJoin ]),
79
+ (" SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2" ,
82
80
classOf [HashOuterJoin ]),
83
- (" SELECT * FROM testData right join testData2 ON key = a and key= 2" ,
81
+ (" SELECT * FROM testData right join testData2 ON key = a and key = 2" ,
84
82
classOf [HashOuterJoin ]),
85
- (" SELECT * FROM testData full outer join testData2 ON key = a" , classOf [HashOuterJoin ]),
86
- (" SELECT * FROM testData join testData2 ON key = a" , classOf [ShuffledHashJoin ]),
87
- (" SELECT * FROM testData join testData2 ON key = a and key=2" , classOf [ShuffledHashJoin ]),
88
- (" SELECT * FROM testData join testData2 ON key = a where key=2" , classOf [ShuffledHashJoin ])
89
- // TODO add BroadcastNestedLoopJoin
90
- )
91
- cases1.foreach { c => assertJoin(c._1, c._2) }
83
+ (" SELECT * FROM testData full outer join testData2 ON key = a" , classOf [HashOuterJoin ])
84
+ // TODO add BroadcastNestedLoopJoin
85
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
86
+ }
87
+
88
+ test(" broadcasted hash join operator selection" ) {
89
+ clearCache()
90
+ sql(" CACHE TABLE testData" )
91
+
92
+ Seq (
93
+ (" SELECT * FROM testData join testData2 ON key = a" , classOf [BroadcastHashJoin ]),
94
+ (" SELECT * FROM testData join testData2 ON key = a and key = 2" , classOf [BroadcastHashJoin ]),
95
+ (" SELECT * FROM testData join testData2 ON key = a where key = 2" , classOf [BroadcastHashJoin ])
96
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
97
+
98
+ sql(" UNCACHE TABLE testData" )
92
99
}
93
100
94
101
test(" multiple-key equi-join is hash-join" ) {
@@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
171
178
(4 , " D" , 4 , " d" ) ::
172
179
(5 , " E" , null , null ) ::
173
180
(6 , " F" , null , null ) :: Nil )
174
-
181
+
175
182
checkAnswer(
176
183
upperCaseData.join(lowerCaseData, LeftOuter , Some (' n === ' N && ' n > 1 )),
177
184
(1 , " A" , null , null ) ::
@@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
180
187
(4 , " D" , 4 , " d" ) ::
181
188
(5 , " E" , null , null ) ::
182
189
(6 , " F" , null , null ) :: Nil )
183
-
190
+
184
191
checkAnswer(
185
192
upperCaseData.join(lowerCaseData, LeftOuter , Some (' n === ' N && ' N > 1 )),
186
193
(1 , " A" , null , null ) ::
@@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
189
196
(4 , " D" , 4 , " d" ) ::
190
197
(5 , " E" , null , null ) ::
191
198
(6 , " F" , null , null ) :: Nil )
192
-
199
+
193
200
checkAnswer(
194
201
upperCaseData.join(lowerCaseData, LeftOuter , Some (' n === ' N && ' l > ' L )),
195
202
(1 , " A" , 1 , " a" ) ::
@@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
300
307
(4 , " D" , 4 , " D" ) ::
301
308
(null , null , 5 , " E" ) ::
302
309
(null , null , 6 , " F" ) :: Nil )
303
-
310
+
304
311
checkAnswer(
305
312
left.join(right, FullOuter , Some ((" left.N" .attr === " right.N" .attr) && (" left.N" .attr !== 3 ))),
306
313
(1 , " A" , null , null ) ::
@@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
310
317
(4 , " D" , 4 , " D" ) ::
311
318
(null , null , 5 , " E" ) ::
312
319
(null , null , 6 , " F" ) :: Nil )
313
-
320
+
314
321
checkAnswer(
315
322
left.join(right, FullOuter , Some ((" left.N" .attr === " right.N" .attr) && (" right.N" .attr !== 3 ))),
316
323
(1 , " A" , null , null ) ::
0 commit comments