
Commit 078d69b

address comments: add comments, do sort in shuffle, and others
1 parent 3af6ba5 commit 078d69b

9 files changed, with 93 additions and 64 deletions


sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 8 additions & 7 deletions

@@ -27,7 +27,6 @@ private[spark] object SQLConf {
   val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
   val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
-  val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
   val CODEGEN_ENABLED = "spark.sql.codegen"
@@ -46,6 +45,7 @@ private[spark] object SQLConf {
   // Options that control which operators can be chosen by the query planner. These should be
   // considered hints and may be ignored by future versions of Spark SQL.
   val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+  val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"

   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -123,6 +123,13 @@ private[sql] class SQLConf extends Serializable {
   /** When true the planner will use the external sort, which may spill to disk. */
   private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean

+  /**
+   * Sort merge join sorts the two sides of the join first, and then iterates both sides together
+   * only once to get all matches. Using sort merge join can save a lot of memory usage compared
+   * to HashJoin.
+   */
+  private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
+
   /**
    * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
    * that evaluates expressions found in queries. In general this custom code runs much faster
@@ -144,12 +151,6 @@ private[sql] class SQLConf extends Serializable {
   private[spark] def autoBroadcastJoinThreshold: Int =
     getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt

-  /**
-   * By default not choose sort merge join.
-   */
-  private[spark] def autoSortMergeJoin: Boolean =
-    getConf(AUTO_SORTMERGEJOIN, false.toString).toBoolean
-
   /**
    * The default size in bytes to assign to a logical operator's estimation statistics. By default,
    * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator

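For context, after this commit the planner hint is controlled by the new `spark.sql.planner.sortMergeJoin` key; the removed `spark.sql.autoSortMergeJoin` key no longer exists. A minimal usage sketch, assuming a local SparkContext and that `testData`/`testData2` are already registered as tables (both assumptions for illustration, not part of this diff):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // A sketch only: the local master, app name, and table names are illustrative assumptions.
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("smj-demo"))
    val sqlContext = new SQLContext(sc)

    // Enable the planner hint added by this commit; it defaults to "false".
    sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")

    // Equi-joins planned after this point may be executed as SortMergeJoin.
    val joined = sqlContext.sql("SELECT * FROM testData JOIN testData2 ON key = a")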
sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 0 additions & 2 deletions

@@ -60,8 +60,6 @@ case class Aggregate(

   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)

-  override def outputOrdering: Seq[SortOrder] = Nil
-
   /**
    * An aggregate that needs to be computed for each row in a group.
    *

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 33 additions & 22 deletions

@@ -32,7 +32,11 @@ import org.apache.spark.util.MutablePair
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
+case class Exchange(
+    newPartitioning: Partitioning,
+    child: SparkPlan,
+    sort: Boolean = false)
+  extends UnaryNode {

   override def outputPartitioning: Partitioning = newPartitioning

@@ -68,7 +72,16 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           }
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        val shuffled = sort match {
+          case false => new ShuffledRDD[Row, Row, Row](rdd, part)
+          case true =>
+            val sortingExpressions = expressions.zipWithIndex.map {
+              case (exp, index) =>
+                new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
+            }
+            val ordering = new RowOrdering(sortingExpressions, child.output)
+            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
+        }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)

@@ -158,37 +171,35 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl

       // Check if the partitioning we want to ensure is the same as the child's output
       // partitioning. If so, we do not need to add the Exchange operator.
-      def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
-        if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
-
-      // Check if the ordering we want to ensure is the same as the child's output
-      // ordering. If so, we do not need to add the Sort operator.
-      def addSortIfNecessary(ordering: Seq[SortOrder], child: SparkPlan): SparkPlan =
-        if (child.outputOrdering != ordering) Sort(ordering, global = false, child) else child
+      def addExchangeIfNecessary(
+          partitioning: Partitioning,
+          child: SparkPlan,
+          rowOrdering: Option[Ordering[Row]] = None): SparkPlan =
+        if (child.outputPartitioning != partitioning) {
+          Exchange(partitioning, child, sort = child.outputOrdering != rowOrdering)
+        } else {
+          child
+        }

       if (meetsRequirements && compatible) {
         operator
       } else {
         // At least one child does not satisfies its required data distribution or
         // at least one child's outputPartitioning is not compatible with another child's
         // outputPartitioning. In this case, we need to add Exchange operators.
-        val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
-          case (AllTuples, child) =>
+        val repartitionedChildren = operator.requiredChildDistribution.zip(
+          operator.children.zip(operator.requiredChildOrdering)
+        ).map {
+          case (AllTuples, (child, _)) =>
             addExchangeIfNecessary(SinglePartition, child)
-          case (ClusteredDistribution(clustering), child) =>
-            addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
-          case (OrderedDistribution(ordering), child) =>
+          case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
+            addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
+          case (OrderedDistribution(ordering), (child, _)) =>
             addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
-          case (UnspecifiedDistribution, child) => child
+          case (UnspecifiedDistribution, (child, _)) => child
           case (dist, _) => sys.error(s"Don't know how to ensure $dist")
         }
-        val reorderedChildren =
-          operator.requiredInPartitionOrdering.zip(repartitionedChildren).map {
-            case (Nil, child) => child
-            case (ordering, child) =>
-              addSortIfNecessary(ordering, child)
-          }
-        operator.withNewChildren(reorderedChildren)
+        operator.withNewChildren(repartitionedChildren)
       }
   }
 }

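The substantive change above is that sorting is folded into the shuffle itself: when `sort = true`, Exchange attaches a key ordering to the ShuffledRDD instead of relying on a separate Sort operator afterwards. A standalone sketch of the same mechanism on plain pair RDDs (the local setup and the Int/String data are assumptions for illustration, not Spark SQL internals):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.rdd.ShuffledRDD

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sort-in-shuffle"))
    val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b", 1 -> "aa"))

    val part = new HashPartitioner(2)
    // With a key ordering set, the shuffle reader sorts each partition by key,
    // so downstream operators see sorted partitions without an extra Sort step.
    val shuffled = new ShuffledRDD[Int, String, String](pairs, part)
      .setKeyOrdering(implicitly[Ordering[Int]])

    // Print the keys of each shuffled partition; within a partition they are ascending.
    shuffled.glom().collect().foreach(p => println(p.map(_._1).mkString(", ")))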
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 2 additions & 3 deletions

@@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     Seq.fill(children.size)(UnspecifiedDistribution)

   /** Specifies how data is ordered in each partition. */
-  def outputOrdering: Seq[SortOrder] = Nil
+  def outputOrdering: Option[Ordering[Row]] = None

   /** Specifies sort order for each partition requirements on the input data for this operator. */
-  def requiredInPartitionOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+  def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None)

   /**
    * Runs this query returning the result as an RDD.
@@ -183,7 +183,6 @@ private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
 private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
   self: Product =>
   override def outputPartitioning: Partitioning = child.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }

 private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

       // for now let's support inner join first, then add outer join
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.conf.autoSortMergeJoin =>
+        if sqlContext.conf.sortMergeJoinEnabled =>
         val mergeJoin =
           joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 16 additions & 5 deletions

@@ -41,6 +41,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
     val resuableProjection = buildProjection()
     iter.map(resuableProjection)
   }
+
+  /**
+   * The outputOrdering of Project is not always the same as the child's outputOrdering,
+   * because the ordering key may be pruned by the projection. If the key is pruned, upper
+   * operators must not require this ordering from the child; only when the projection
+   * leaves the ordering columns unchanged can the ordering be preserved.
+   * TODO: we may utilize this feature later to avoid some unnecessary sorting.
+   */
+  override def outputOrdering: Option[Ordering[Row]] = None
 }

 /**
@@ -55,6 +64,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   override def execute(): RDD[Row] = child.execute().mapPartitions { iter =>
     iter.filter(conditionEvaluator)
   }
+
+  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
 }

 /**
@@ -70,8 +81,6 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
   override def execute(): RDD[Row] = {
     child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
   }
-
-  override def outputOrdering: Seq[SortOrder] = Nil
 }

 /**
@@ -104,6 +113,8 @@ case class Limit(limit: Int, child: SparkPlan)
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition

+  override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
+
   override def executeCollect(): Array[Row] = child.executeTake(limit)

   override def execute(): RDD[Row] = {
@@ -149,7 +160,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
   // TODO: Pick num splits based on |limit|.
   override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)

-  override def outputOrdering: Seq[SortOrder] = sortOrder
+  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
 }

 /**
@@ -176,7 +187,7 @@ case class Sort(

   override def output: Seq[Attribute] = child.output

-  override def outputOrdering: Seq[SortOrder] = sortOrder
+  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
 }

 /**
@@ -208,7 +219,7 @@ case class ExternalSort(

   override def output: Seq[Attribute] = child.output

-  override def outputOrdering: Seq[SortOrder] = sortOrder
+  override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
 }

 /**

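The new comment on Project is easier to see with a concrete case: once a projection drops the column the rows were sorted by, the remaining columns carry no usable ordering, which is why Project conservatively reports `None`. A tiny illustration in plain Scala (the record type and data are made up for the example):

    // Rows sorted by key; projecting away "key" leaves the values in an order that is
    // meaningless on its own, so downstream operators must not assume any ordering.
    case class Rec(key: Int, value: String)

    val sortedByKey = Seq(Rec(1, "pear"), Rec(2, "apple"), Rec(3, "mango"))
    val projected = sortedByKey.map(_.value) // Seq("pear", "apple", "mango")

    assert(projected != projected.sorted) // order by key does not imply order by value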
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 29 additions & 20 deletions

@@ -17,6 +17,8 @@

 package org.apache.spark.sql.execution.joins

+import java.util.NoSuchElementException
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
@@ -47,16 +49,16 @@ case class SortMergeJoin(
   private val orders: Seq[SortOrder] = leftKeys.zipWithIndex.map {
     case(expr, index) => SortOrder(BoundReference(index, expr.dataType, expr.nullable), Ascending)
   }
-  private val ordering: RowOrdering = new RowOrdering(orders, left.output)
+  // manually construct an ordering that can be used to compare keys from both sides
+  private val keyOrdering: RowOrdering = new RowOrdering(orders)

-  private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Seq[SortOrder] = keys.map {
-    k => SortOrder(BindReferences.bindReference(k, side.output, allowFailures = false), Ascending)
-  }
+  private def requiredOrders(keys: Seq[Expression], side: SparkPlan): Ordering[Row] =
+    newOrdering(keys.map(SortOrder(_, Ascending)), side.output)

-  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys, left)
+  override def outputOrdering: Option[Ordering[Row]] = Some(requiredOrders(leftKeys, left))

-  override def requiredInPartitionOrdering: Seq[Seq[SortOrder]] =
-    requiredOrders(leftKeys, left) :: requiredOrders(rightKeys, right) :: Nil
+  override def requiredChildOrdering: Seq[Option[Ordering[Row]]] =
+    Some(requiredOrders(leftKeys, left)) :: Some(requiredOrders(rightKeys, right)) :: Nil

   @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
   @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
@@ -78,24 +80,28 @@ case class SortMergeJoin(
         private[this] var stop: Boolean = false
         private[this] var matchKey: Row = _

+        // initialize iterator
+        initialize()
+
         override final def hasNext: Boolean = nextMatchingPair()

         override final def next(): Row = {
           if (hasNext) {
+            // use the buffered right rows and run down the left iterator
             val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
             rightPosition += 1
             if (rightPosition >= rightMatches.size) {
               rightPosition = 0
               fetchLeft()
-              if (leftElement == null || ordering.compare(leftKey, matchKey) != 0) {
+              if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
                 stop = false
                 rightMatches = null
               }
             }
             joinedRow
           } else {
-            // according to Scala doc, this is undefined
-            null
+            // no more results
+            throw new NoSuchElementException
           }
         }

@@ -121,33 +127,36 @@ case class SortMergeJoin(
           fetchLeft()
           fetchRight()
         }
-        // initialize iterator
-        initialize()

         /**
-         * Searches the left/right iterator for the next rows that matches.
+         * Searches the right iterator for the next rows that have matches on the left side,
+         * and stores them in a buffer.
          *
-         * @return true if the search is successful, and false if the left/right iterator runs out
-         *         of tuples.
+         * @return true if the search is successful, and false if the right iterator runs out of
+         *         tuples.
          */
         private def nextMatchingPair(): Boolean = {
           if (!stop && rightElement != null) {
+            // run down both sides to get the first matching pair
             while (!stop && leftElement != null && rightElement != null) {
-              stop = ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull
-              if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) {
+              val comparing = keyOrdering.compare(leftKey, rightKey)
+              // for inner join, we need to filter out null keys
+              stop = comparing == 0 && !leftKey.anyNull
+              if (comparing > 0 || rightKey.anyNull) {
                 fetchRight()
-              } else if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull) {
+              } else if (comparing < 0 || leftKey.anyNull) {
                 fetchLeft()
               }
             }
             rightMatches = new CompactBuffer[Row]()
             if (stop) {
               stop = false
+              // iterate the right side to buffer all rows that match;
+              // since the records are ordered, exit when we meet the first one that does not match
               while (!stop && rightElement != null) {
                 rightMatches += rightElement
                 fetchRight()
-                // exit loop when run out of right matches
-                stop = ordering.compare(leftKey, rightKey) != 0
+                stop = keyOrdering.compare(leftKey, rightKey) != 0
               }
               if (rightMatches.size > 0) {
                 rightPosition = 0

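For readers less familiar with the operator, the iterator above is the classic sort-merge inner join: advance whichever side has the smaller key, and when the keys are equal, buffer all right rows for that key and pair them with every left row carrying the same key. A self-contained sketch over already-sorted sequences (plain Scala, not Spark's Row/iterator machinery; all names and types are illustrative):

    // A minimal sort-merge inner join sketch over two inputs already sorted by key.
    // It mirrors the structure of nextMatchingPair above: buffer the right rows for the
    // current matching key, then emit one joined tuple per matching left row.
    def sortMergeJoin[K: Ordering, L, R](
        left: Seq[(K, L)],
        right: Seq[(K, R)]): Seq[(K, L, R)] = {
      val ord = implicitly[Ordering[K]]
      val out = scala.collection.mutable.ArrayBuffer.empty[(K, L, R)]
      var i = 0
      var j = 0
      while (i < left.length && j < right.length) {
        val cmp = ord.compare(left(i)._1, right(j)._1)
        if (cmp < 0) i += 1            // left key too small, advance left
        else if (cmp > 0) j += 1       // right key too small, advance right
        else {
          val key = left(i)._1
          // buffer all right rows with this key (contiguous because the input is sorted)
          val buffered = right.drop(j).takeWhile(r => ord.equiv(r._1, key)).map(_._2)
          // pair every left row carrying this key against the buffered right rows
          while (i < left.length && ord.equiv(left(i)._1, key)) {
            buffered.foreach(r => out += ((key, left(i)._2, r)))
            i += 1
          }
          j += buffered.length
        }
      }
      out.toSeq
    }

    // Example: keys 2 and 3 match; keys 1 and 4 do not.
    // sortMergeJoin(Seq(1 -> "a", 2 -> "b", 3 -> "c"), Seq(2 -> "x", 3 -> "y", 4 -> "z"))
    // => Seq((2, "b", "x"), (3, "c", "y"))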
sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -64,7 +64,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   test("join operator selection") {
     cacheManager.clearCache()

-    val AUTO_SORTMERGEJOIN: Boolean = conf.autoSortMergeJoin
+    val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled
     conf.setConf("spark.sql.autoSortMergeJoin", "false")
     Seq(
       ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
@@ -103,7 +103,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
         ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin])
       ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
     } finally {
-      conf.setConf("spark.sql.autoSortMergeJoin", AUTO_SORTMERGEJOIN.toString)
+      conf.setConf("spark.sql.autoSortMergeJoin", SORTMERGEJOIN_ENABLED.toString)
     }
   }

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala

Lines changed: 2 additions & 2 deletions

@@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive
 class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
   override def beforeAll() {
     super.beforeAll()
-    TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "true")
+    TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true")
   }

   override def afterAll() {
-    TestHive.setConf(SQLConf.AUTO_SORTMERGEJOIN, "false")
+    TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false")
     super.afterAll()
   }
