@@ -19,14 +19,15 @@ package org.apache.spark.sql.execution
 
 import java.sql.{Timestamp, Date}
 
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext}
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
 import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
-
-import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest}
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
 
 class SparkSqlSerializer2DataTypeSuite extends FunSuite {
   // Make sure that we will not use serializer2 for unsupported data types.
@@ -67,18 +68,17 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite {
 }
 
 abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
-
-  @transient var sparkContext: SparkContext = _
-  @transient var sqlContext: SQLContext = _
-  // We may have an existing SparkEnv (e.g. the one used by TestSQLContext).
-  @transient val existingSparkEnv = SparkEnv.get
   var allColumns: String = _
   val serializerClass: Class[Serializer] =
     classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+  var numShufflePartitions: Int = _
+  var useSerializer2: Boolean = _
 
   override def beforeAll(): Unit = {
-    sqlContext.sql("set spark.sql.shuffle.partitions=5")
-    sqlContext.sql("set spark.sql.useSerializer2=true")
+    numShufflePartitions = conf.numShufflePartitions
+    useSerializer2 = conf.useSqlSerializer2
+
+    sql("set spark.sql.useSerializer2=true")
 
     val supportedTypes =
       Seq(StringType, BinaryType, NullType, BooleanType,
@@ -112,18 +112,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
         new Timestamp(i))
     }
 
-    sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle")
+    createDataFrame(rdd, schema).registerTempTable("shuffle")
 
     super.beforeAll()
   }
 
   override def afterAll(): Unit = {
-    sqlContext.dropTempTable("shuffle")
-    sparkContext.stop()
-    sqlContext = null
-    sparkContext = null
-    // Set the existing SparkEnv back.
-    SparkEnv.set(existingSparkEnv)
+    dropTempTable("shuffle")
+    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    sql(s"set spark.sql.useSerializer2=$useSerializer2")
     super.afterAll()
   }
 
@@ -144,64 +141,40 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("key schema and value schema are not nulls") {
-    val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
-      sqlContext.table("shuffle").collect())
+      table("shuffle").collect())
   }
 
   test("value schema is null") {
-    val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     assert(
       df.map(r => r.getString(0)).collect().toSeq ===
-        sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+        table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+  }
+}
+
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Sort merge will not be triggered.
+    sql("set spark.sql.shuffle.partitions = 200")
   }
 
   test("key schema is null") {
     val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
       Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
   }
 }
 
-/** Tests SparkSqlSerializer2 with hash based shuffle. */
-class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "hash")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    // Since spark.sql.shuffle.partition is 5, we will not do sort merge when
-    // spark.shuffle.sort.bypassMergeThreshold is also 5.
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "5")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
@@ -210,15 +183,8 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
     classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
 
   override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge.
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
     super.beforeAll()
+    // To trigger the sort merge.
+    sql("set spark.sql.shuffle.partitions = 201")
   }
 }
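
The heart of this change is that the suites stop constructing their own SparkContext and SQLContext and instead reuse the shared TestSQLContext, recording the SQLConf values they are about to mutate in beforeAll and putting them back in afterAll. Below is a minimal, self-contained sketch of that save-and-restore pattern; SimpleConf, SharedConfSuite, and the keys used are hypothetical stand-ins for the shared SQLConf, not code from this patch.

import scala.collection.mutable

import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical stand-in for a shared, mutable configuration such as TestSQLContext's SQLConf.
object SimpleConf {
  private val settings = mutable.Map(
    "spark.sql.shuffle.partitions" -> "5",
    "spark.sql.useSerializer2" -> "false")
  def get(key: String): String = settings(key)
  def set(key: String, value: String): Unit = settings(key) = value
}

// The suite records the values it changes and restores them afterwards, so other
// suites sharing the same context still see the original settings.
class SharedConfSuite extends FunSuite with BeforeAndAfterAll {
  private var savedPartitions: String = _
  private var savedSerializer2: String = _

  override def beforeAll(): Unit = {
    savedPartitions = SimpleConf.get("spark.sql.shuffle.partitions")
    savedSerializer2 = SimpleConf.get("spark.sql.useSerializer2")
    SimpleConf.set("spark.sql.useSerializer2", "true")
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    SimpleConf.set("spark.sql.shuffle.partitions", savedPartitions)
    SimpleConf.set("spark.sql.useSerializer2", savedSerializer2)
    super.afterAll()
  }

  test("serializer2 is enabled while this suite runs") {
    assert(SimpleConf.get("spark.sql.useSerializer2") === "true")
  }
}

As for the 200/201 partition settings in the two concrete suites, they presumably sit on either side of the sort-shuffle bypass threshold (spark.shuffle.sort.bypassMergeThreshold, which defaults to 200), so one suite exercises the bypass path while the other forces the sort-merge path.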