Skip to content

Commit 8681d73

Browse files
committed
refactor Exchange and fix copy for sorting
1 parent 2875ef2 commit 8681d73

File tree

2 files changed: +7 additions, −6 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ import org.apache.spark.util.MutablePair
 @DeveloperApi
 case class Exchange(
     newPartitioning: Partitioning,
-    child: SparkPlan,
-    sort: Boolean = false)
+    sort: Boolean,
+    child: SparkPlan)
   extends UnaryNode {

   override def outputPartitioning: Partitioning = newPartitioning
@@ -59,7 +59,7 @@ case class Exchange(
       // we can avoid the defensive copies to improve performance. In the long run, we probably
       // want to include information in shuffle dependencies to indicate whether elements in the
       // source RDD should be copied.
-      val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+      val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
        child.execute().mapPartitions { iter =>
          val hashExpressions = newMutableProjection(expressions, child.output)()
          iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -178,7 +178,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
       val needSort = child.outputOrdering != rowOrdering
       if (child.outputPartitioning != partitioning || needSort) {
         // TODO: if only needSort, we need only sort each partition instead of an Exchange
-        Exchange(partitioning, child, sort = needSort)
+        Exchange(partitioning, sort = needSort, child)
       } else {
         child
       }
@@ -197,7 +197,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
         addExchangeIfNecessary(SinglePartition, child)
       case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
         addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
-      case (OrderedDistribution(ordering), (child, _)) =>
+      case (OrderedDistribution(ordering), (child, None)) =>
         addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
       case (UnspecifiedDistribution, (child, _)) => child
       case (dist, _) => sys.error(s"Don't know how to ensure $dist")

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+        execution.Exchange(
+          HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil

0 commit comments

Comments (0)