@@ -34,8 +34,8 @@ import org.apache.spark.util.MutablePair
34
34
@ DeveloperApi
35
35
case class Exchange (
36
36
newPartitioning : Partitioning ,
37
- child : SparkPlan ,
38
- sort : Boolean = false )
37
+ sort : Boolean ,
38
+ child : SparkPlan )
39
39
extends UnaryNode {
40
40
41
41
override def outputPartitioning : Partitioning = newPartitioning
@@ -59,7 +59,7 @@ case class Exchange(
59
59
// we can avoid the defensive copies to improve performance. In the long run, we probably
60
60
// want to include information in shuffle dependencies to indicate whether elements in the
61
61
// source RDD should be copied.
62
- val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
62
+ val rdd = if (( sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort ) {
63
63
child.execute().mapPartitions { iter =>
64
64
val hashExpressions = newMutableProjection(expressions, child.output)()
65
65
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -178,7 +178,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
178
178
val needSort = child.outputOrdering != rowOrdering
179
179
if (child.outputPartitioning != partitioning || needSort) {
180
180
// TODO: if only needSort, we need only sort each partition instead of an Exchange
181
- Exchange (partitioning, child, sort = needSort)
181
+ Exchange (partitioning, sort = needSort, child )
182
182
} else {
183
183
child
184
184
}
@@ -197,7 +197,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
197
197
addExchangeIfNecessary(SinglePartition , child)
198
198
case (ClusteredDistribution (clustering), (child, rowOrdering)) =>
199
199
addExchangeIfNecessary(HashPartitioning (clustering, numPartitions), child, rowOrdering)
200
- case (OrderedDistribution (ordering), (child, _ )) =>
200
+ case (OrderedDistribution (ordering), (child, None )) =>
201
201
addExchangeIfNecessary(RangePartitioning (ordering, numPartitions), child)
202
202
case (UnspecifiedDistribution , (child, _)) => child
203
203
case (dist, _) => sys.error(s " Don't know how to ensure $dist" )
0 commit comments