
Commit 880d8e9

sort merge join for spark sql
1 parent 5db8912 commit 880d8e9

File tree

6 files changed: 235 additions, 5 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 49 additions & 0 deletions

@@ -75,6 +75,21 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
+/**
+ * Represents data where tuples have been ordered according to the `clustering`
+ * [[Expression Expressions]]. This is a strictly stronger guarantee than
+ * [[ClusteredDistribution]] as this will ensure that tuples in a single partition are sorted
+ * by the expressions.
+ */
+case class ClusteredOrderedDistribution(clustering: Seq[Expression])
+  extends Distribution {
+  require(
+    clustering != Nil,
+    "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+}
+
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int

@@ -162,6 +177,40 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
 
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition. In addition, rows within the same partition are sorted by the
+ * expressions.
+ */
+case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends Expression
+  with Partitioning {
+
+  override def children = expressions
+  override def nullable = false
+  override def dataType = IntegerType
+
+  private[this] lazy val clusteringSet = expressions.toSet
+
+  override def satisfies(required: Distribution): Boolean = required match {
+    case UnspecifiedDistribution => true
+    case ClusteredOrderedDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case ClusteredDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case _ => false
+  }
+
+  override def compatibleWith(other: Partitioning) = other match {
+    case BroadcastPartitioning => true
+    case h: HashSortedPartitioning if h == this => true
+    case _ => false
+  }
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
+
 /**
  * Represents a partitioning where rows are split across partitions based on some total ordering of
  * the expressions specified in `ordering`. When data is partitioned in this manner the following
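The `satisfies` check above is what lets the planner skip a shuffle when a child's output is already partitioned and sorted the way a sort merge join needs. A minimal standalone model of that subset rule (plain Scala; the type and object names here are illustrative stand-ins, not the real Catalyst classes):

```scala
// Simplified stand-ins for Catalyst's Distribution/Partitioning hierarchy,
// to show how satisfies() decides whether an Exchange must be inserted.
sealed trait Dist
case object Unspecified extends Dist
case object AllTuples extends Dist
case class Clustered(keys: Set[String]) extends Dist
case class ClusteredOrdered(keys: Set[String]) extends Dist

case class HashSorted(keys: Set[String]) {
  // Same rule as HashSortedPartitioning.satisfies: the partitioning's keys
  // must be a subset of the required clustering keys.
  def satisfies(required: Dist): Boolean = required match {
    case Unspecified               => true
    case ClusteredOrdered(reqKeys) => keys.subsetOf(reqKeys)
    case Clustered(reqKeys)        => keys.subsetOf(reqKeys)
    case _                         => false
  }
}

object SatisfiesDemo extends App {
  val part = HashSorted(Set("key"))
  println(part.satisfies(ClusteredOrdered(Set("key"))))   // true: no Exchange needed
  println(part.satisfies(ClusteredOrdered(Set("other")))) // false: Exchange inserted
}
```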

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 7 additions & 0 deletions

@@ -27,6 +27,7 @@ private[spark] object SQLConf {
   val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
   val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
+  val AUTO_SORTMERGEJOIN = "spark.sql.autoSortMergeJoin"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
   val CODEGEN_ENABLED = "spark.sql.codegen"

@@ -143,6 +144,12 @@ private[sql] class SQLConf extends Serializable {
   private[spark] def autoBroadcastJoinThreshold: Int =
     getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt
 
+  /**
+   * When true (the default), the planner prefers sort merge join for equi-joins.
+   */
+  private[spark] def autoSortMergeJoin: Boolean =
+    getConf(AUTO_SORTMERGEJOIN, true.toString).toBoolean
+
   /**
    * The default size in bytes to assign to a logical operator's estimation statistics. By default,
    * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator
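Because the new flag defaults to true, sort merge join is chosen automatically; it can be toggled per session through the usual `SQLContext.setConf`. A minimal usage sketch (assumes an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc) // assumes an existing SparkContext `sc`

// Fall back to the previous hash-based joins by turning the new flag off.
sqlContext.setConf("spark.sql.autoSortMergeJoin", "false")

// Restore the default behavior.
sqlContext.setConf("spark.sql.autoSortMergeJoin", "true")
```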

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 23 additions & 2 deletions

@@ -19,12 +19,11 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair

@@ -73,6 +72,26 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
+      case HashSortedPartitioning(expressions, numPartitions) =>
+        val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+          child.execute().mapPartitions { iter =>
+            val hashExpressions = newMutableProjection(expressions, child.output)()
+            iter.map(r => (hashExpressions(r).copy(), r.copy()))
+          }
+        } else {
+          child.execute().mapPartitions { iter =>
+            val hashExpressions = newMutableProjection(expressions, child.output)()
+            val mutablePair = new MutablePair[Row, Row]()
+            iter.map(r => mutablePair.update(hashExpressions(r), r))
+          }
+        }
+        val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
+        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        val part = new HashPartitioner(numPartitions)
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
+        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+        shuffled.map(_._2)
+
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val rdd = if (sortBasedShuffleOn) {
           child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}

@@ -173,6 +192,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
           addExchangeIfNecessary(SinglePartition, child)
         case (ClusteredDistribution(clustering), child) =>
           addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
+        case (ClusteredOrderedDistribution(clustering), child) =>
+          addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
         case (OrderedDistribution(ordering), child) =>
           addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
         case (UnspecifiedDistribution, child) => child
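The new `HashSortedPartitioning` case pairs a `HashPartitioner` with `setKeyOrdering`, so the shuffle both co-locates equal keys and delivers each partition sorted by the join keys. Core Spark exposes the same guarantee on pair RDDs through `repartitionAndSortWithinPartitions`; a small illustrative sketch (object and app names are hypothetical):

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits on older Spark versions

// Illustration of the "hash partition, then sort within each partition"
// contract that the HashSortedPartitioning case above builds by hand.
object HashSortedShuffleDemo extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("demo"))

  val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b", 1 -> "x"))

  // Equal keys land in the same partition, and each partition arrives sorted
  // by key -- the precondition SortMergeJoin's merge loop depends on.
  val partitioned = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2))

  partitioned.glom().collect().foreach(p => println(p.mkString(", ")))
  sc.stop()
}
```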

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 7 additions & 0 deletions

@@ -90,6 +90,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
 
+      // For now only inner joins are supported; outer joins will be added later.
+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+        if sqlContext.conf.autoSortMergeJoin =>
+        val mergeJoin =
+          joins.SortMergeJoin(leftKeys, rightKeys, Inner, planLater(left), planLater(right))
+        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
         val buildSide =
           if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
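Because this case is matched before the existing `ShuffledHashJoin` case, an inner equi-join now plans to `SortMergeJoin` whenever the flag is on. One way to confirm which operator was chosen is to inspect the physical plan (a sketch; `testData` and `testData2` are the registered test tables used in JoinSuite below):

```scala
// Sketch: inspect the physical plan of an inner equi-join.
val plan = sqlContext.sql("SELECT * FROM testData JOIN testData2 ON key = a")
  .queryExecution.executedPlan
println(plan) // should contain a SortMergeJoin node when
              // spark.sql.autoSortMergeJoin is true
```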
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge join of two child relations.
+ */
+@DeveloperApi
+case class SortMergeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] =
+    ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
+
+  private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending))
+  private val ordering: RowOrdering = new RowOrdering(orders, left.output)
+
+  @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
+  @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+
+  override def execute() = {
+    val leftResults = left.execute().map(_.copy())
+    val rightResults = right.execute().map(_.copy())
+
+    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+      new Iterator[Row] {
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow5
+        private[this] var leftElement: Row = _
+        private[this] var rightElement: Row = _
+        private[this] var leftKey: Row = _
+        private[this] var rightKey: Row = _
+        // True when leftElement was read ahead of the current match group and
+        // has not been processed yet.
+        private[this] var read: Boolean = false
+        private[this] var currentlMatches: CompactBuffer[Row] = _
+        private[this] var currentrMatches: CompactBuffer[Row] = _
+        private[this] var currentlPosition: Int = -1
+        private[this] var currentrPosition: Int = -1
+
+        override final def hasNext: Boolean =
+          (currentlPosition != -1 && currentlPosition < currentlMatches.size) ||
+            (leftIter.hasNext && rightIter.hasNext && nextMatchingPair)
+
+        override final def next(): Row = {
+          val joinedRow =
+            joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition))
+          currentrPosition += 1
+          if (currentrPosition >= currentrMatches.size) {
+            currentlPosition += 1
+            currentrPosition = 0
+          }
+          joinedRow
+        }
+
+        /**
+         * Searches the left/right iterator for the next rows that match.
+         *
+         * @return true if the search is successful, and false if the left/right iterator runs out
+         *         of tuples.
+         */
+        private def nextMatchingPair(): Boolean = {
+          currentlPosition = -1
+          currentlMatches = null
+          if (rightElement == null) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+          }
+          while (currentlMatches == null && leftIter.hasNext) {
+            if (!read) {
+              leftElement = leftIter.next()
+              leftKey = leftKeyGenerator(leftElement)
+            } else read = false // consume the pending left row read by the previous group
+            while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) {
+              rightElement = rightIter.next()
+              rightKey = rightKeyGenerator(rightElement)
+            }
+            currentrMatches = new CompactBuffer[Row]()
+            while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) {
+              currentrMatches += rightElement
+              rightElement = rightIter.next()
+              rightKey = rightKeyGenerator(rightElement)
+            }
+            if (ordering.compare(leftKey, rightKey) == 0) {
+              currentrMatches += rightElement
+            }
+            if (currentrMatches.size > 0) {
+              // Matching rows exist in the right table; collect all left rows
+              // that share the current key.
+              currentlMatches = new CompactBuffer[Row]()
+              val leftMatch = leftKey.copy()
+              while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) {
+                currentlMatches += leftElement
+                leftElement = leftIter.next()
+                leftKey = leftKeyGenerator(leftElement)
+              }
+              if (ordering.compare(leftKey, leftMatch) == 0) {
+                currentlMatches += leftElement
+              } else {
+                read = true
+              }
+            }
+          }
+
+          if (currentlMatches == null) {
+            false
+          } else {
+            currentlPosition = 0
+            currentrPosition = 0
+            true
+          }
+        }
+      }
+    }
+  }
+}
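The heart of the operator is `nextMatchingPair`: advance whichever side has the smaller key, then buffer both groups of equal-keyed rows and emit their cross product. The same algorithm is easier to read over plain key-sorted sequences; the sketch below is illustrative only, not the operator's actual API:

```scala
// Simplified inner sort merge join over two key-sorted sequences. Both inputs
// must be sorted ascending by key, mirroring the ClusteredOrderedDistribution
// requirement of the operator above.
def sortMergeJoin[K: Ordering, L, R](
    left: Seq[(K, L)],
    right: Seq[(K, R)]): Seq[(K, (L, R))] = {
  val ord = implicitly[Ordering[K]]
  val out = Seq.newBuilder[(K, (L, R))]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val c = ord.compare(left(i)._1, right(j)._1)
    if (c < 0) i += 1      // left key smaller: advance left
    else if (c > 0) j += 1 // right key smaller: advance right
    else {
      // Equal keys: find both match groups and emit their cross product.
      val key = left(i)._1
      val iEnd = left.indexWhere(p => ord.compare(p._1, key) != 0, i) match {
        case -1 => left.length
        case n => n
      }
      val jEnd = right.indexWhere(p => ord.compare(p._1, key) != 0, j) match {
        case -1 => right.length
        case n => n
      }
      for (x <- i until iEnd; y <- j until jEnd) {
        out += ((key, (left(x)._2, right(y)._2)))
      }
      i = iEnd
      j = jEnd
    }
  }
  out.result()
}

// Example: only key 1 matches on both sides.
// sortMergeJoin(Seq(1 -> "a", 1 -> "b", 3 -> "c"), Seq(1 -> "x", 2 -> "y"))
// => Seq((1,("a","x")), (1,("b","x")))
```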

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 4 additions & 3 deletions

@@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
       case j: BroadcastLeftSemiJoinHash => j
+      case j: SortMergeJoin => j
     }
 
     assert(operators.size === 1)

@@ -75,9 +76,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
       ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
        classOf[HashOuterJoin]),
