
Commit 43b9fb4

Test.
1 parent 8297732 commit 43b9fb4

1 file changed (+32 -66 lines)

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala

+32 -66
@@ -19,14 +19,15 @@ package org.apache.spark.sql.execution
 
 import java.sql.{Timestamp, Date}
 
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext}
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
 import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
-
-import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest}
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
 
 class SparkSqlSerializer2DataTypeSuite extends FunSuite {
   // Make sure that we will not use serializer2 for unsupported data types.
@@ -67,18 +68,17 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite {
 }
 
 abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
-
-  @transient var sparkContext: SparkContext = _
-  @transient var sqlContext: SQLContext = _
-  // We may have an existing SparkEnv (e.g. the one used by TestSQLContext).
-  @transient val existingSparkEnv = SparkEnv.get
   var allColumns: String = _
   val serializerClass: Class[Serializer] =
     classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+  var numShufflePartitions: Int = _
+  var useSerializer2: Boolean = _
 
   override def beforeAll(): Unit = {
-    sqlContext.sql("set spark.sql.shuffle.partitions=5")
-    sqlContext.sql("set spark.sql.useSerializer2=true")
+    numShufflePartitions = conf.numShufflePartitions
+    useSerializer2 = conf.useSqlSerializer2
+
+    sql("set spark.sql.useSerializer2=true")
 
     val supportedTypes =
       Seq(StringType, BinaryType, NullType, BooleanType,
@@ -112,18 +112,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
         new Timestamp(i))
     }
 
-    sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle")
+    createDataFrame(rdd, schema).registerTempTable("shuffle")
 
     super.beforeAll()
   }
 
   override def afterAll(): Unit = {
-    sqlContext.dropTempTable("shuffle")
-    sparkContext.stop()
-    sqlContext = null
-    sparkContext = null
-    // Set the existing SparkEnv back.
-    SparkEnv.set(existingSparkEnv)
+    dropTempTable("shuffle")
+    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    sql(s"set spark.sql.useSerializer2=$useSerializer2")
     super.afterAll()
   }
 
@@ -144,64 +141,40 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("key schema and value schema are not nulls") {
-    val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
-      sqlContext.table("shuffle").collect())
+      table("shuffle").collect())
   }
 
   test("value schema is null") {
-    val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     assert(
       df.map(r => r.getString(0)).collect().toSeq ===
-      sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+      table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+  }
+}
+
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Sort merge will not be triggered.
+    sql("set spark.sql.shuffle.partitions = 200")
   }
 
   test("key schema is null") {
     val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
       Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
   }
 }
 
-/** Tests SparkSqlSerializer2 with hash based shuffle. */
-class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "hash")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    // Since spark.sql.shuffle.partition is 5, we will not do sort merge when
-    // spark.shuffle.sort.bypassMergeThreshold is also 5.
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "5")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
@@ -210,15 +183,8 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
     classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
 
   override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge.
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
     super.beforeAll()
+    // To trigger the sort merge.
+    sql("set spark.sql.shuffle.partitions = 201")
   }
 }
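The reworked suites above no longer build their own SparkContext/SQLContext. They run against the shared TestSQLContext, snapshot the current spark.sql.shuffle.partitions and spark.sql.useSerializer2 values in beforeAll, restore them in afterAll, and each concrete suite only adjusts the shuffle partition count relative to the sort-shuffle bypass threshold (200 to avoid sort merge, 201 to trigger it). Below is a minimal, hedged sketch of that snapshot-and-restore test pattern. FakeSqlConf, SnapshotConfSuite, and LowPartitionCountSuite are hypothetical names invented for illustration; only FunSuite and BeforeAndAfterAll from the imports in this diff are assumed.

import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical stand-in for the shared, mutable SQL configuration.
object FakeSqlConf {
  var numShufflePartitions: Int = 200
  var useSerializer2: Boolean = true
}

// Snapshot shared settings before the suite runs and restore them afterwards,
// so one suite's overrides do not leak into other suites sharing the same context.
abstract class SnapshotConfSuite extends FunSuite with BeforeAndAfterAll {
  private var savedPartitions: Int = _
  private var savedUseSerializer2: Boolean = _

  override def beforeAll(): Unit = {
    savedPartitions = FakeSqlConf.numShufflePartitions
    savedUseSerializer2 = FakeSqlConf.useSerializer2
    FakeSqlConf.useSerializer2 = true
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    FakeSqlConf.numShufflePartitions = savedPartitions
    FakeSqlConf.useSerializer2 = savedUseSerializer2
    super.afterAll()
  }
}

// A concrete suite only tweaks the partition count, as the suites in this diff do.
class LowPartitionCountSuite extends SnapshotConfSuite {
  override def beforeAll(): Unit = {
    super.beforeAll()
    FakeSqlConf.numShufflePartitions = 200 // at the bypass threshold: no sort merge
  }

  test("partition count is applied") {
    assert(FakeSqlConf.numShufflePartitions === 200)
  }
}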
