
Commit 413fd24

Merge pull request #3 from marmbrus/pr/5208
Cleanup addition of ordering requirements
2 parents b198278 + 952168a

8 files changed: +133 -70 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
 }
 
 object RowOrdering {
-  def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering =
+  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
     new RowOrdering(dataTypes.zipWithIndex.map {
       case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
     })
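
Editor's note (not part of the commit): a minimal usage sketch of the renamed helper, assuming the 1.4-era catalyst packages and a hypothetical two-column schema.

    import org.apache.spark.sql.catalyst.expressions.GenericRow
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // forSchema orders rows column by column, ascending, with nullable columns allowed.
    val ordering = RowOrdering.forSchema(Seq(IntegerType, StringType))
    val x = new GenericRow(Array[Any](1, "b"))
    val y = new GenericRow(Array[Any](1, "a"))
    assert(ordering.compare(x, y) > 0) // ties on column 0 fall through to column 1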

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 13 additions & 0 deletions
@@ -94,6 +94,9 @@ sealed trait Partitioning {
    * only compatible if the `numPartitions` of them is the same.
    */
   def compatibleWith(other: Partitioning): Boolean
+
+  /** Returns the expressions that are used to key the partitioning. */
+  def keyExpressions: Seq[Expression]
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
     case UnknownPartitioning(_) => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 /**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = expressions
+
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
   override def eval(input: Row): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
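
Editor's note: `keyExpressions` is what lets the planner ask whether a desired sort can piggyback on a shuffle. A hedged sketch of the subset test it enables (mirroring `Exchange.canSortWithShuffle` in the Exchange.scala diff below; `a`, `b`, and `c` are hypothetical resolved attributes):

    val partitioning = HashPartitioning(Seq(a, b), numPartitions = 200)
    val desired = Seq(SortOrder(a, Ascending))

    // The sort can be folded into the shuffle only if every ordering expression is a key column.
    desired.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) // true
    // An ordering on the non-key column c would make the same test return false.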

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -1080,7 +1080,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
     val batches =
-      Batch("Add exchange", Once, AddExchange(self)) :: Nil
+      Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
   }
 
   protected[sql] def openSession(): SQLSession = {

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 102 additions & 47 deletions
@@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
 
+object Exchange {
+  /** Returns true when the ordering expressions are a subset of the key. */
+  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+  }
+}
+
 /**
- * Shuffle data according to a new partition rule, and sort inside each partition if necessary.
- * @param newPartitioning The new partitioning way that required by parent
- * @param sort Whether we will sort inside each partition
- * @param child Child operator
+ * :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
+ * resulting partition based on expressions from the partition key. It is invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
  */
 @DeveloperApi
 case class Exchange(
     newPartitioning: Partitioning,
-    sort: Boolean,
+    newOrdering: Seq[SortOrder],
     child: SparkPlan)
   extends UnaryNode {
 
   override def outputPartitioning: Partitioning = newPartitioning
 
+  override def outputOrdering: Seq[SortOrder] = newOrdering
+
   override def output: Seq[Attribute] = child.output
 
   /** We must copy rows when sort based shuffle is on */
@@ -51,6 +60,20 @@ case class Exchange(
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
+  private val keyOrdering = {
+    if (newOrdering.nonEmpty) {
+      val key = newPartitioning.keyExpressions
+      val boundOrdering = newOrdering.map { o =>
+        val ordinal = key.indexOf(o.child)
+        if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
+        o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
+      }
+      new RowOrdering(boundOrdering)
+    } else {
+      null // Ordering will not be used
+    }
+  }
+
   override def execute(): RDD[Row] = attachTree(this , "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
@@ -62,7 +85,9 @@ case class Exchange(
         // we can avoid the defensive copies to improve performance. In the long run, we probably
         // want to include information in shuffle dependencies to indicate whether elements in the
         // source RDD should be copied.
-        val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
+        val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
+
+        val rdd = if (willMergeSort || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -75,21 +100,17 @@ case class Exchange(
           }
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = sort match {
-          case false => new ShuffledRDD[Row, Row, Row](rdd, part)
-          case true =>
-            val sortingExpressions = expressions.zipWithIndex.map {
-              case (exp, index) =>
-                new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
-            }
-            val ordering = new RowOrdering(sortingExpressions, child.output)
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
-        }
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Row, Row](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn) {
+        val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
           child.execute().mapPartitions { iter =>
@@ -102,7 +123,12 @@ case class Exchange(
         implicit val ordering = new RowOrdering(sortingExpressions, child.output)
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Null, Null](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
 
         shuffled.map(_._1)
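
Editor's note: `keyOrdering` above is bound against the shuffle key rather than the child's full output, because `setKeyOrdering` compares the projected key rows, where each key expression sits at its ordinal within `keyExpressions`. A sketch of the rebinding with hypothetical attributes `a` and `b`:

    // Partitioning key is (a, b); the requested ordering is b ASC. In the key row,
    // b sits at ordinal 1, so the SortOrder is rebound to position 1, not b's
    // position in the child's output.
    val key = Seq(a, b)
    val desired = SortOrder(b, Ascending)
    val ordinal = key.indexOf(desired.child) // 1
    val bound = desired.copy(child = BoundReference(ordinal, b.dataType, b.nullable))
    val keyRowOrdering = new RowOrdering(Seq(bound))
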
@@ -135,27 +161,35 @@ case class Exchange(
  * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required. Also ensure that the
+ * required input partition ordering requirements are met.
  */
-private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
   def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator: SparkPlan =>
-      // Check if every child's outputPartitioning satisfies the corresponding
+      // True iff every child's outputPartitioning satisfies the corresponding
       // required data distribution.
       def meetsRequirements: Boolean =
-        !operator.requiredChildDistribution.zip(operator.children).map {
+        operator.requiredChildDistribution.zip(operator.children).forall {
           case (required, child) =>
             val valid = child.outputPartitioning.satisfies(required)
             logDebug(
               s"${if (valid) "Valid" else "Invalid"} distribution," +
                 s"required: $required current: ${child.outputPartitioning}")
             valid
-        }.exists(!_)
+        }
+
+      // True iff any of the children are incorrectly sorted.
+      def needsAnySort: Boolean =
+        operator.requiredChildOrdering.zip(operator.children).exists {
+          case (required, child) => required.nonEmpty && required != child.outputOrdering
+        }
+
 
-      // Check if outputPartitionings of children are compatible with each other.
+      // True iff outputPartitionings of children are compatible with each other.
       // It is possible that every child satisfies its required data distribution
       // but two children have incompatible outputPartitionings. For example,
       // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
           case Seq(a,b) => a compatibleWith b
         }.exists(!_)
 
-      // Check if the partitioning we want to ensure is the same as the child's output
-      // partitioning. If so, we do not need to add the Exchange operator.
-      def addExchangeIfNecessary(
+      // Adds Exchange or Sort operators as required
+      def addOperatorsIfNecessary(
           partitioning: Partitioning,
-          child: SparkPlan,
-          rowOrdering: Option[Ordering[Row]] = None): SparkPlan = {
-        val needSort = child.outputOrdering != rowOrdering
-        if (child.outputPartitioning != partitioning || needSort) {
-          // TODO: if only needSort, we need only sort each partition instead of an Exchange
-          Exchange(partitioning, sort = needSort, child)
+          rowOrdering: Seq[SortOrder],
+          child: SparkPlan): SparkPlan = {
+        val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
+        val needsShuffle = child.outputPartitioning != partitioning
+        val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
+
+        if (needSort && needsShuffle && canSortWithShuffle) {
+          Exchange(partitioning, rowOrdering, child)
         } else {
-          child
+          val withShuffle = if (needsShuffle) {
+            Exchange(partitioning, Nil, child)
+          } else {
+            child
+          }
+
+          val withSort = if (needSort) {
+            Sort(rowOrdering, global = false, withShuffle)
+          } else {
+            withShuffle
+          }
+
+          withSort
         }
       }
 
-      if (meetsRequirements && compatible) {
+      if (meetsRequirements && compatible && !needsAnySort) {
         operator
       } else {
         // At least one child does not satisfies its required data distribution or
         // at least one child's outputPartitioning is not compatible with another child's
         // outputPartitioning. In this case, we need to add Exchange operators.
-        val repartitionedChildren = operator.requiredChildDistribution.zip(
-          operator.children.zip(operator.requiredChildOrdering)
-        ).map {
-          case (AllTuples, (child, _)) =>
-            addExchangeIfNecessary(SinglePartition, child)
-          case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
-            addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
-          case (OrderedDistribution(ordering), (child, None)) =>
-            addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
-          case (UnspecifiedDistribution, (child, _)) => child
-          case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+        val requirements =
+          (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+
+        val fixedChildren = requirements.zipped.map {
+          case (AllTuples, rowOrdering, child) =>
+            addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+          case (ClusteredDistribution(clustering), rowOrdering, child) =>
+            addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+          case (OrderedDistribution(ordering), rowOrdering, child) =>
+            addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child)
+
+          case (UnspecifiedDistribution, Seq(), child) =>
+            child
+          case (UnspecifiedDistribution, rowOrdering, child) =>
+            Sort(rowOrdering, global = false, child)
+
+          case (dist, ordering, _) =>
+            sys.error(s"Don't know how to ensure $dist with ordering $ordering")
         }
-        operator.withNewChildren(repartitionedChildren)
+
+        operator.withNewChildren(fixedChildren)
       }
   }
 }
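
Editor's note: the repair logic reduces to three cases; a worked example under the diff's own definitions (`k` is a hypothetical join key, `n` the shuffle partition count):

    // An operator requires ClusteredDistribution(Seq(k)) plus ordering Seq(SortOrder(k, Ascending))
    // over a child that is neither partitioned nor sorted by k. Since the ordering expression k
    // is a subset of the HashPartitioning key, canSortWithShuffle holds and the rule inserts a
    // single Exchange that shuffles by k and sorts each partition during the shuffle:
    //   Exchange(HashPartitioning(Seq(k), n), Seq(SortOrder(k, Ascending)), child)
    // Had the required ordering been on a non-key column c, shuffle and sort would be split:
    //   Sort(Seq(SortOrder(c, Ascending)), global = false,
    //     Exchange(HashPartitioning(Seq(k), n), Nil, child))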

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 2 additions & 2 deletions
@@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     Seq.fill(children.size)(UnspecifiedDistribution)
 
   /** Specifies how data is ordered in each partition. */
-  def outputOrdering: Option[Ordering[Row]] = None
+  def outputOrdering: Seq[SortOrder] = Nil
 
   /** Specifies sort order for each partition requirements on the input data for this operator. */
-  def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None)
+  def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
   /**
    * Runs this query returning the result as an RDD.
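
Editor's note: with `Seq[SortOrder]` replacing the opaque `Option[Ordering[Row]]`, requirements become declarative expressions the planner can compare and satisfy. A hedged sketch of how a binary operator (say, a sort-based join; names hypothetical) would declare both properties under the new API:

    // leftKeys/rightKeys: the join key expressions for each child.
    override def requiredChildDistribution: Seq[Distribution] =
      ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

    override def requiredChildOrdering: Seq[Seq[SortOrder]] =
      leftKeys.map(SortOrder(_, Ascending)) :: rightKeys.map(SortOrder(_, Ascending)) :: Nil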

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
     case logical.Repartition(expressions, child) =>
       execution.Exchange(
-        HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
+        HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
     case e @ EvaluatePython(udf, child, _) =>
       BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
     case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 6 additions & 11 deletions
@@ -42,12 +42,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
     iter.map(resuableProjection)
   }
 
-  /**
-   * outputOrdering of Project is not always same with child's outputOrdering if the certain
-   * key is pruned, however, if the key is pruned then we must not require child using this
-   * ordering from upper layer, so it is fine to keep it to avoid some unnecessary sorting.
-   */
-  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
@@ -63,7 +58,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
     iter.filter(conditionEvaluator)
   }
 
-  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
@@ -111,7 +106,7 @@ case class Limit(limit: Int, child: SparkPlan)
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
 
-  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def executeCollect(): Array[Row] = child.executeTake(limit)
 
@@ -158,7 +153,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
   // TODO: Pick num splits based on |limit|.
   override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -185,7 +180,7 @@ case class Sort(
 
   override def output: Seq[Attribute] = child.output
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
@@ -217,7 +212,7 @@ case class ExternalSort(
 
   override def output: Seq[Attribute] = child.output
 
-  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
+  override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
 /**
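
Editor's note: because operators now report `sortOrder` directly instead of a constructed comparator, the "already sorted?" check in EnsureRequirements is plain equality on expressions. A sketch (`a` and `somePlan` hypothetical):

    val required = Seq(SortOrder(a, Ascending))
    val child = Sort(required, global = false, somePlan)
    // child.outputOrdering == required, so no redundant Sort or Exchange is added above it.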
