
Commit 645c70b

address comments using sort

1 parent: 068c35d

File tree

5 files changed: +42, -77 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
Lines changed: 0 additions & 49 deletions

@@ -75,21 +75,6 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
-/**
- * Represents data where tuples have been ordered according to the `clustering`
- * [[Expression Expressions]]. This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as this will ensure that tuples in a single partition are sorted
- * by the expressions.
- */
-case class ClusteredOrderedDistribution(clustering: Seq[Expression])
-  extends Distribution {
-  require(
-    clustering != Nil,
-    "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
-      "An AllTuples should be used to represent a distribution that only has " +
-      "a single partition.")
-}
-
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
@@ -177,40 +162,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
-/**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
- * in the same partition. And rows within the same partition are sorted by the expressions.
- */
-case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression
-  with Partitioning {
-
-  override def children: Seq[Expression] = expressions
-  override def nullable: Boolean = false
-  override def dataType: DataType = IntegerType
-
-  private[this] lazy val clusteringSet = expressions.toSet
-
-  override def satisfies(required: Distribution): Boolean = required match {
-    case UnspecifiedDistribution => true
-    case ClusteredOrderedDistribution(requiredClustering) =>
-      clusteringSet.subsetOf(requiredClustering.toSet)
-    case ClusteredDistribution(requiredClustering) =>
-      clusteringSet.subsetOf(requiredClustering.toSet)
-    case _ => false
-  }
-
-  override def compatibleWith(other: Partitioning) = other match {
-    case BroadcastPartitioning => true
-    case h: HashSortedPartitioning if h == this => true
-    case _ => false
-  }
-
-  override def eval(input: Row = null): EvaluatedType =
-    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-}
-
 /**
  * Represents a partitioning where rows are split across partitions based on some total ordering of
  * the expressions specified in `ordering`. When data is partitioned in this manner the following
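
The two deletions above drop ClusteredOrderedDistribution and HashSortedPartitioning, each of which bundled a clustering guarantee with an in-partition sort guarantee. The rest of the commit tracks those properties separately: an Exchange satisfies the clustering requirement, and a standalone Sort satisfies the ordering requirement. A minimal, self-contained sketch of that decomposition, using toy stand-ins rather than Spark's real classes:

sealed trait Dist
case object Unspecified extends Dist
case class Clustered(keys: Set[String]) extends Dist

case class HashPart(keys: Set[String]) {
  // Hash partitioning satisfies a clustering requirement when its keys are a
  // subset of the required clustering keys; this is the same subset test that
  // HashPartitioning already performs, so no "sorted" variant is needed.
  def satisfies(required: Dist): Boolean = required match {
    case Unspecified        => true
    case Clustered(reqKeys) => keys.subsetOf(reqKeys)
  }
}

object DecompositionDemo extends App {
  val part = HashPart(Set("a"))
  assert(part.satisfies(Clustered(Set("a", "b"))))
  // The "rows sorted within each partition" half of the old guarantee is now
  // expressed as requiredInPartitionOrdering and enforced by a Sort node.
  println("clustering satisfied; in-partition ordering handled by Sort")
}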

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
Lines changed: 12 additions & 26 deletions

@@ -72,29 +72,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
-      case HashSortedPartitioning(expressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            iter.map(r => (hashExpressions(r).copy(), r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            val mutablePair = new MutablePair[Row, Row]()
-            iter.map(r => mutablePair.update(hashExpressions(r), r))
-          }
-        }
-        val sortingExpressions = expressions.zipWithIndex.map {
-          case (exp, index) =>
-            new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
-        }
-        val ordering = new RowOrdering(sortingExpressions, child.output)
-        val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
-        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-        shuffled.map(_._2)
-
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val rdd = if (sortBasedShuffleOn) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
@@ -184,6 +161,11 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
     def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
       if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
 
+    // Check if the ordering we want to ensure is the same as the child's output
+    // ordering. If so, we do not need to add the Sort operator.
+    def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan =
+      if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child
+
     if (meetsRequirements && compatible) {
       operator
     } else {
@@ -195,14 +177,18 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
           addExchangeIfNecessary(SinglePartition, child)
         case (ClusteredDistribution(clustering), child) =>
           addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
-        case (ClusteredOrderedDistribution(clustering), child) =>
-          addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
         case (OrderedDistribution(ordering), child) =>
           addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
         case (UnspecifiedDistribution, child) => child
        case (dist, _) => sys.error(s"Don't know how to ensure $dist")
       }
-      operator.withNewChildren(repartitionedChildren)
+      val reorderedChildren = operator.requiredInPartitionOrdering.zip(repartitionedChildren).map {
+        case (Nil, child) =>
+          child
+        case (ordering, child) =>
+          addSortIfNecessary(ordering, child)
+      }
+      operator.withNewChildren(reorderedChildren)
     }
   }
 }
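
The new addSortIfNecessary is the ordering counterpart of addExchangeIfNecessary: it wraps a child in a non-global (per-partition) Sort only when the child's advertised outputOrdering differs from what the parent requires. A runnable toy sketch of that check, where Plan, Leaf, and Sort are simplified stand-ins rather than Spark's operators:

object SortPlanningDemo extends App {
  case class SortKey(col: String)

  sealed trait Plan {
    def outputOrdering: Seq[SortKey] = Nil
  }
  case class Leaf(name: String) extends Plan
  case class Sort(order: Seq[SortKey], child: Plan) extends Plan {
    override def outputOrdering: Seq[SortKey] = order
  }

  // Mirrors addSortIfNecessary: wrap the child in a Sort only when its
  // output ordering does not already match the required ordering.
  def addSortIfNecessary(ordering: Seq[SortKey], child: Plan): Plan =
    if (child.outputOrdering != ordering) Sort(ordering, child) else child

  val key = Seq(SortKey("k"))
  val unsorted = Leaf("scan")
  val presorted = Sort(key, Leaf("scan"))

  assert(addSortIfNecessary(key, unsorted) == Sort(key, unsorted)) // Sort added
  assert(addSortIfNecessary(key, presorted) == presorted)          // reused as-is
}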

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
Lines changed: 6 additions & 0 deletions

@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 
+  /** Specifies how data is ordered in each partition. */
+  def outputOrdering: Seq[SortOrder] = Nil
+
+  /** Specifies the sort order required within each partition of this operator's input data. */
+  def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
   /**
    * Runs this query returning the result as an RDD.
    */
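
These two defaults define the planner-facing contract: an operator advertises the per-partition ordering it produces (outputOrdering) and the ordering it requires of each child (requiredInPartitionOrdering), with Nil meaning no guarantee and no requirement. A toy model of an operator overriding both, using illustrative names rather than Spark's API:

object OrderingContractDemo extends App {
  case class Order(col: String)

  trait Op {
    def children: Seq[Op]
    def outputOrdering: Seq[Order] = Nil                // ordering this op emits
    def requiredInPartitionOrdering: Seq[Seq[Order]] =  // ordering each child must provide
      Seq.fill(children.size)(Nil)
  }

  case class Scan(name: String) extends Op { def children: Seq[Op] = Nil }

  // A merge-join-like operator: needs both inputs sorted on the join key and
  // preserves that ordering on its output, so a parent merge join on the same
  // key would not trigger another sort.
  case class MergeJoinLike(key: String, left: Op, right: Op) extends Op {
    def children: Seq[Op] = Seq(left, right)
    override def outputOrdering: Seq[Order] = Seq(Order(key))
    override def requiredInPartitionOrdering: Seq[Seq[Order]] =
      Seq(Seq(Order(key)), Seq(Order(key)))
  }

  val join = MergeJoinLike("k", Scan("a"), Scan("b"))
  // Neither Scan provides Seq(Order("k")), so the planner would insert a
  // per-partition Sort above both scans.
  assert(join.requiredInPartitionOrdering == Seq(Seq(Order("k")), Seq(Order("k"))))
  assert(join.outputOrdering == Seq(Order("k")))
}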

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
Lines changed: 11 additions & 2 deletions

@@ -39,16 +39,25 @@ case class SortMergeJoin(
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
-  override def outputPartitioning: Partitioning = HashSortedPartitioning(leftKeys, 0)
+  override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map {
    case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending)
   }
   private val ordering: RowOrdering = new RowOrdering(orders, left.output)
 
+  private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map {
+    k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending)
+  }
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left)
+
+  override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil
+
   @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
   @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
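
With both children guaranteed clustered and sorted on the join keys, the join itself reduces to a linear merge of two sorted streams, which is why SortMergeJoin can now demand only ClusteredDistribution plus a per-side sort. A self-contained sketch of that merge over in-memory sequences (illustrative only; Spark's implementation streams Rows through iterators):

object MergeJoinDemo extends App {
  def mergeJoin[K, A, B](left: Seq[(K, A)], right: Seq[(K, B)])
                        (implicit ord: Ordering[K]): Seq[(K, A, B)] = {
    val out = Seq.newBuilder[(K, A, B)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1        // left key too small: advance left
      else if (c > 0) j += 1   // right key too small: advance right
      else {
        // Keys match: emit the cross product of both key groups.
        val k = left(i)._1
        val jStart = j
        while (i < left.length && left(i)._1 == k) {
          var jj = jStart
          while (jj < right.length && right(jj)._1 == k) {
            out += ((k, left(i)._2, right(jj)._2))
            jj += 1
          }
          i += 1
        }
        while (j < right.length && right(j)._1 == k) j += 1
      }
    }
    out.result()
  }

  val l = Seq(1 -> "a", 1 -> "b", 2 -> "c")
  val r = Seq(1 -> "x", 3 -> "y")
  assert(mergeJoin(l, r) == Seq((1, "a", "x"), (1, "b", "x")))
}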

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
Lines changed: 13 additions & 0 deletions

@@ -70,6 +70,19 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
     "auto_sortmerge_join_7",
     "auto_sortmerge_join_8",
     "auto_sortmerge_join_9",
+    "correlationoptimizer1",
+    "correlationoptimizer10",
+    "correlationoptimizer11",
+    "correlationoptimizer13",
+    "correlationoptimizer14",
+    "correlationoptimizer15",
+    "correlationoptimizer2",
+    "correlationoptimizer3",
+    "correlationoptimizer4",
+    "correlationoptimizer6",
+    "correlationoptimizer7",
+    "correlationoptimizer8",
+    "correlationoptimizer9",
     "join0",
     "join1",
     "join10",
