@@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair

+object Exchange {
+  /** Returns true when the ordering expressions are a subset of the key. */
+  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+  }
+}
+
 /**
- * Shuffle data according to a new partition rule, and sort inside each partition if necessary.
- * @param newPartitioning The new partitioning way that required by parent
- * @param sort Whether we will sort inside each partition
- * @param child Child operator
+ * :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
+ * resulting partition based on expressions from the partition key. It is invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
  */
 @DeveloperApi
 case class Exchange(
     newPartitioning: Partitioning,
-    sort: Boolean,
+    newOrdering: Seq[SortOrder],
     child: SparkPlan)
   extends UnaryNode {

   override def outputPartitioning: Partitioning = newPartitioning

+  override def outputOrdering: Seq[SortOrder] = newOrdering
+
   override def output: Seq[Attribute] = child.output

   /** We must copy rows when sort based shuffle is on */
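Read as a contract, `canSortWithShuffle` says an exchange may only take on a sort whose expressions all come from the partitioning key. A hedged sketch (attributes `a`, `b`, `c` are made up; assumes `HashPartitioning.keyExpressions` returns its hash expressions, as this patch relies on):

  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.catalyst.plans.physical._
  // IntegerType lives in catalyst.types in Spark 1.2 and org.apache.spark.sql.types from 1.3 on.

  val a = AttributeReference("a", IntegerType, nullable = true)()
  val b = AttributeReference("b", IntegerType, nullable = true)()
  val c = AttributeReference("c", IntegerType, nullable = true)()
  val partitioning = HashPartitioning(Seq(a, b), 200)

  // Ordering drawn entirely from the key: the shuffle itself can produce it.
  Exchange.canSortWithShuffle(partitioning, Seq(SortOrder(a, Ascending)))  // true
  // Ordering on a non-key column: a separate Sort operator is required.
  Exchange.canSortWithShuffle(partitioning, Seq(SortOrder(c, Ascending)))  // false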
@@ -51,6 +60,20 @@ case class Exchange(
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

+  private val keyOrdering = {
+    if (newOrdering.nonEmpty) {
+      val key = newPartitioning.keyExpressions
+      val boundOrdering = newOrdering.map { o =>
+        val ordinal = key.indexOf(o.child)
+        if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
+        o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
+      }
+      new RowOrdering(boundOrdering)
+    } else {
+      null // Ordering will not be used
+    }
+  }
+
   override def execute(): RDD[Row] = attachTree(this, "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
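Concretely, `keyOrdering` rebinds each requested `SortOrder` against the projected shuffle key rather than the full input row. A hedged sketch, reusing the illustrative attributes from the earlier example:

  // Key (a, b) with requested ordering b.asc: b sits at ordinal 1 of the key row.
  val key = Seq(a, b)                          // newPartitioning.keyExpressions
  val requested = SortOrder(b, Ascending)
  val ordinal = key.indexOf(requested.child)   // 1
  val bound = requested.copy(child = BoundReference(ordinal, b.dataType, b.nullable))
  val ordering = new RowOrdering(Seq(bound))   // an Ordering[Row] over key rows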
@@ -62,7 +85,9 @@ case class Exchange(
         // we can avoid the defensive copies to improve performance. In the long run, we probably
         // want to include information in shuffle dependencies to indicate whether elements in the
         // source RDD should be copied.
-        val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
+        val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
+
+        val rdd = if (willMergeSort || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
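The copies guard against row reuse: operators may recycle one mutable row across an iterator, so a shuffle that buffers rows would otherwise hold many references to the same object. A hypothetical standalone illustration of the aliasing (assumes catalyst's `GenericMutableRow`):

  import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

  val reused = new GenericMutableRow(1)
  val aliased = (1 to 3).iterator.map { i => reused(0) = i; reused }.toArray
  // aliased holds three references to one row; every element now reads 3.
  val copied = (1 to 3).iterator.map { i => reused(0) = i; reused.copy() }.toArray
  // copied holds three independent rows reading 1, 2, 3.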
@@ -75,21 +100,17 @@ case class Exchange(
           }
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = sort match {
-          case false => new ShuffledRDD[Row, Row, Row](rdd, part)
-          case true =>
-            val sortingExpressions = expressions.zipWithIndex.map {
-              case (exp, index) =>
-                new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
-            }
-            val ordering = new RowOrdering(sortingExpressions, child.output)
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
-        }
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Row, Row](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)

       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn) {
+        val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
           child.execute().mapPartitions { iter =>
@@ -102,7 +123,12 @@ case class Exchange(
         implicit val ordering = new RowOrdering(sortingExpressions, child.output)

         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Null, Null](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))

         shuffled.map(_._1)
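In both branches the actual sorting is delegated to `setKeyOrdering`, the `ShuffledRDD` hook that sort-based shuffle uses to emit each partition in key order. A minimal standalone sketch with plain pair RDDs rather than SQL rows:

  import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
  import org.apache.spark.rdd.ShuffledRDD

  val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("key-ordering-demo"))
  val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
  // With a key ordering set, each post-shuffle partition arrives sorted by key.
  val sorted = new ShuffledRDD[Int, String, String](pairs, new HashPartitioner(2))
    .setKeyOrdering(Ordering.Int)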
@@ -135,27 +161,35 @@ case class Exchange(
  * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required. Also ensures that the
+ * input partition ordering requirements of each operator are met.
  */
-private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
   def numPartitions: Int = sqlContext.conf.numShufflePartitions

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator: SparkPlan =>
-      // Check if every child's outputPartitioning satisfies the corresponding
+      // True iff every child's outputPartitioning satisfies the corresponding
       // required data distribution.
       def meetsRequirements: Boolean =
-        !operator.requiredChildDistribution.zip(operator.children).map {
+        operator.requiredChildDistribution.zip(operator.children).forall {
           case (required, child) =>
             val valid = child.outputPartitioning.satisfies(required)
             logDebug(
               s"${if (valid) "Valid" else "Invalid"} distribution, " +
                 s"required: $required current: ${child.outputPartitioning}")
             valid
-        }.exists(!_)
+        }
+
+      // True iff any of the children are incorrectly sorted.
+      def needsAnySort: Boolean =
+        operator.requiredChildOrdering.zip(operator.children).exists {
+          case (required, child) => required.nonEmpty && required != child.outputOrdering
+        }
+

-      // Check if outputPartitionings of children are compatible with each other.
+      // True iff outputPartitionings of children are compatible with each other.
       // It is possible that every child satisfies its required data distribution
       // but two children have incompatible outputPartitionings. For example,
       // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
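The `meetsRequirements` rewrite above is behavior-preserving; it relies on the De Morgan-style identity between a negated `exists` of negations and `forall`:

  val xs = Seq(1, 2, 3)
  val p = (i: Int) => i > 0
  assert(!xs.map(p).exists(!_) == xs.forall(p))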
@@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
         case Seq(a, b) => a compatibleWith b
       }.exists(!_)

-      // Check if the partitioning we want to ensure is the same as the child's output
-      // partitioning. If so, we do not need to add the Exchange operator.
-      def addExchangeIfNecessary(
+      // Adds Exchange or Sort operators as required
+      def addOperatorsIfNecessary(
           partitioning: Partitioning,
-          child: SparkPlan,
-          rowOrdering: Option[Ordering[Row]] = None): SparkPlan = {
-        val needSort = child.outputOrdering != rowOrdering
-        if (child.outputPartitioning != partitioning || needSort) {
-          // TODO: if only needSort, we need only sort each partition instead of an Exchange
-          Exchange(partitioning, sort = needSort, child)
+          rowOrdering: Seq[SortOrder],
+          child: SparkPlan): SparkPlan = {
+        val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
+        val needsShuffle = child.outputPartitioning != partitioning
+        val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
+
+        if (needSort && needsShuffle && canSortWithShuffle) {
+          Exchange(partitioning, rowOrdering, child)
         } else {
-          child
+          val withShuffle = if (needsShuffle) {
+            Exchange(partitioning, Nil, child)
+          } else {
+            child
+          }
+
+          val withSort = if (needSort) {
+            Sort(rowOrdering, global = false, withShuffle)
+          } else {
+            withShuffle
+          }
+
+          withSort
         }
       }

-      if (meetsRequirements && compatible) {
+      if (meetsRequirements && compatible && !needsAnySort) {
         operator
       } else {
         // At least one child does not satisfies its required data distribution or
         // at least one child's outputPartitioning is not compatible with another child's
         // outputPartitioning. In this case, we need to add Exchange operators.
-        val repartitionedChildren = operator.requiredChildDistribution.zip(
-          operator.children.zip(operator.requiredChildOrdering)
-        ).map {
-          case (AllTuples, (child, _)) =>
-            addExchangeIfNecessary(SinglePartition, child)
-          case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
-            addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
-          case (OrderedDistribution(ordering), (child, None)) =>
-            addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
-          case (UnspecifiedDistribution, (child, _)) => child
-          case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+        val requirements =
+          (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+
+        val fixedChildren = requirements.zipped.map {
+          case (AllTuples, rowOrdering, child) =>
+            addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+          case (ClusteredDistribution(clustering), rowOrdering, child) =>
+            addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+          case (OrderedDistribution(ordering), rowOrdering, child) =>
+            addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child)
+
+          case (UnspecifiedDistribution, Seq(), child) =>
+            child
+          case (UnspecifiedDistribution, rowOrdering, child) =>
+            Sort(rowOrdering, global = false, child)
+
+          case (dist, ordering, _) =>
+            sys.error(s"Don't know how to ensure $dist with ordering $ordering")
         }
-        operator.withNewChildren(repartitionedChildren)
+
+        operator.withNewChildren(fixedChildren)
       }
   }
 }
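Taken together, `addOperatorsIfNecessary` reduces to a small decision procedure; a hedged distillation with strings standing in for physical plans:

  def plan(needSort: Boolean, needsShuffle: Boolean, canSortWithShuffle: Boolean): String =
    if (needSort && needsShuffle && canSortWithShuffle) {
      "Exchange(sorting)"                             // one shuffle that also sorts
    } else {
      val shuffled = if (needsShuffle) "Exchange(child)" else "child"
      if (needSort) s"Sort($shuffled)" else shuffled  // partition-local sort on top
    }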