@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
@@ -45,6 +46,27 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
+  def serializer(
+      keySchema: Array[DataType],
+      valueSchema: Array[DataType],
+      numPartitions: Int): Serializer = {
+    val useSqlSerializer2 =
+      !(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) &&
+        child.sqlContext.conf.useSqlSerializer2 &&
+        SparkSqlSerializer2.support(keySchema) &&
+        SparkSqlSerializer2.support(valueSchema)
+
+    val serializer = if (useSqlSerializer2) {
+      logInfo("Use ShuffleSerializer")
+      new SparkSqlSerializer2(keySchema, valueSchema)
+    } else {
+      logInfo("Use SparkSqlSerializer")
+      new SparkSqlSerializer(new SparkConf(false))
+    }
+
+    serializer
+  }
+
   override def execute(): RDD[Row] = attachTree(this, "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
@@ -70,7 +92,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
         val part = new HashPartitioner(numPartitions)
         val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
+        val keySchema = expressions.map(_.dataType).toArray
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -88,7 +114,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
         val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
+        val keySchema = sortingExpressions.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
 
         shuffled.map(_._1)
 
@@ -107,7 +135,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(null, valueSchema, 1))
+
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
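As a reading aid, the predicate that picks between the two serializers in the diff above can be restated in isolation. The sketch below is not the PR's code: the object name, parameter names, and the main method are illustrative stand-ins for the fields and calls that serializer() uses; only the boolean expression mirrors the diff.

// Standalone Scala sketch of the selection logic in serializer() above.
// All names here are illustrative stand-ins, not Spark internals.
object SerializerChoiceSketch {
  def useSerializer2(
      sortBasedShuffleOn: Boolean,   // stands in for: the active shuffle manager is SortShuffleManager
      numPartitions: Int,            // shuffle partitions for this Exchange
      bypassMergeThreshold: Int,     // stands in for: spark.shuffle.sort.bypassMergeThreshold
      serializer2Enabled: Boolean,   // stands in for: child.sqlContext.conf.useSqlSerializer2
      keySupported: Boolean,         // stands in for: SparkSqlSerializer2.support(keySchema)
      valueSupported: Boolean        // stands in for: SparkSqlSerializer2.support(valueSchema)
  ): Boolean = {
    // SparkSqlSerializer2 is picked only when the sort-based shuffle stays at or
    // below the bypass-merge threshold (or sort shuffle is off), the SQL conf
    // enables it, and both the key and value schemas are supported.
    !(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) &&
      serializer2Enabled && keySupported && valueSupported
  }

  def main(args: Array[String]): Unit = {
    // Example: sort shuffle on with 300 partitions against the default threshold
    // of 200, so the generic SparkSqlSerializer would be used instead (prints false).
    println(useSerializer2(sortBasedShuffleOn = true, numPartitions = 300,
      bypassMergeThreshold = 200, serializer2Enabled = true,
      keySupported = true, valueSupported = true))
  }
}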