
Commit bc0cb84

Rewrite join implementation to allow streaming of one relation.
1 parent 1fa48d9 commit bc0cb84

5 files changed, +113 -40 lines
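The change replaces SparkEquiInnerJoin, which paired up both relations with PartitionLocalRDDFunctions.joinLocally, with a HashJoin operator: per partition, one side (the build side) is fully materialized into an in-memory hash table keyed on its join keys, and the other side is streamed past that table one row at a time, so only the build relation has to fit in memory. A minimal standalone sketch of the two-phase pattern follows; rows are modeled as plain Seq[Any], and the names hashJoin, buildKey, and streamKey are invented for illustration (the real operator uses catalyst Row, Projection, and RDD.zipPartitions):

import scala.collection.mutable.ArrayBuffer

object HashJoinSketch {
  def hashJoin(
      build: Iterator[Seq[Any]],
      stream: Iterator[Seq[Any]],
      buildKey: Seq[Any] => Seq[Any],
      streamKey: Seq[Any] => Seq[Any]): Iterator[Seq[Any]] = {
    // Phase 1: materialize only the build side into a multi-map of key -> rows.
    val table = new java.util.HashMap[Seq[Any], ArrayBuffer[Seq[Any]]]()
    while (build.hasNext) {
      val row = build.next()
      val key = buildKey(row)
      if (!key.contains(null)) { // a key with any NULL can never equi-match
        var bucket = table.get(key)
        if (bucket == null) {
          bucket = new ArrayBuffer[Seq[Any]]()
          table.put(key, bucket)
        }
        bucket += row
      }
    }
    // Phase 2: stream the other side, emitting one output row per match.
    // A streamed key containing NULL simply finds no bucket.
    stream.flatMap { row =>
      val bucket = table.get(streamKey(row))
      if (bucket == null) Iterator.empty
      else bucket.iterator.map(row ++ _) // streamed columns first, as with BuildRight
    }
  }
}

Because only phase 1 consumes its whole input, the memory footprint is bounded by the build side alone, which is what the commit title means by streaming one relation.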

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 10 additions & 0 deletions
@@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable {
     s"[${this.mkString(",")}]"
 
   def copy(): Row
+
+  /** Returns true if there are any NULL values in this row. */
+  def anyNull: Boolean = {
+    var i = 0
+    while(i < length) {
+      if(isNullAt(i)) return true
+      i += 1
+    }
+    false
+  }
 }
 
 /**
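Row.anyNull is the helper both join paths rely on to honor SQL's three-valued logic: a join key containing NULL can never satisfy an equi-join condition, so such rows are skipped on both the build and stream sides (see joins.scala below). A schematic check of the contract, using Seq[Any] as a hypothetical stand-in for a catalyst Row:

object AnyNullDemo extends App {
  // Stand-in for Row.anyNull: the same loop, over a plain Seq[Any].
  def anyNull(key: Seq[Any]): Boolean = {
    var i = 0
    while (i < key.length) {
      if (key(i) == null) return true
      i += 1
    }
    false
  }

  assert(anyNull(Seq("a", null, 3)))   // dropped: never inserted or probed
  assert(!anyNull(Seq("a", "b", 3)))   // kept: eligible to match
}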

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TopK ::
       PartialAggregation ::
-      SparkEquiInnerJoin ::
+      HashJoin ::
       ParquetOperations ::
       BasicOperators ::
       CartesianProduct ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
 
-  object SparkEquiInnerJoin extends Strategy {
+  object HashJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
         logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val leftKeys = joinKeys.map(_._1)
         val rightKeys = joinKeys.map(_._2)
 
-        val joinOp = execution.SparkEquiInnerJoin(
-          leftKeys, rightKeys, planLater(left), planLater(right))
+        val joinOp = execution.HashJoin(
+          leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
 
         // Make sure other conditions are met if present.
         if (otherPredicates.nonEmpty) {
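Note that the strategy passes BuildRight unconditionally, so the right relation is always the one hashed. A cost-aware planner would presumably choose the smaller input instead; a sketch using the BuildSide ADT this commit introduces, where the size estimates are hypothetical inputs rather than anything the commit provides:

// Hypothetical: choose the smaller relation as the build side.
// BuildSide/BuildLeft/BuildRight come from joins.scala below;
// the size estimates are assumed to come from elsewhere.
def chooseBuildSide(leftSizeBytes: Long, rightSizeBytes: Long): BuildSide =
  if (rightSizeBytes <= leftSizeBytes) BuildRight else BuildLeft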

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 98 additions & 35 deletions
@@ -15,23 +15,33 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql
+package execution
 
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import catalyst.errors._
+import catalyst.expressions._
+import catalyst.plans._
+import catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
 
-case class SparkEquiInnerJoin(
+object InterpretCondition {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -40,33 +50,85 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
   def output = left.output ++ right.output
 
-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {
 
-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      var currentRow: Row = null
+
+      // Create a mapping of buildKeys -> rows
+      while(buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if(!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }
 
-  /**
-   * Filters any rows where the any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
+      new Iterator[Row] {
+        private[this] var currentRow: Row = _
+        private[this] var currentMatches: ArrayBuffer[Row] = _
+        private[this] var currentPosition: Int = -1
+
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow
+
+        @transient private val joinKeys = streamSideKeyGenerator()
+
+        def hasNext: Boolean =
+          (currentPosition != -1 && currentPosition < currentMatches.size) ||
+          (streamIter.hasNext && fetchNext())
+
+        def next() = {
+          val ret = joinRow(currentRow, currentMatches(currentPosition))
+          currentPosition += 1
+          ret
+        }
+
+        private def fetchNext(): Boolean = {
+          currentMatches = null
+          currentPosition = -1
+
+          while (currentMatches == null && streamIter.hasNext) {
+            currentRow = streamIter.next()
+            if(!joinKeys(currentRow).anyNull)
+              currentMatches = hashTable.get(joinKeys.currentValue)
+          }
+
+          if (currentMatches == null) {
+            false
+          } else {
+            currentPosition = 0
+            true
+          }
+        }
+      }
     }
+  }
 }
 
 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +157,18 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
-    condition
-      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-      .getOrElse(Literal(true))
+    InterpretCondition(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
 
 
   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
 
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
            matchedRows += buildRow(streamedRow ++ broadcastedRow)
            matched = true
            includedBroadcastTuples += i
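The result of execute() above is produced by a hand-written Iterator rather than a flatMap: fetchNext() advances the streamed side until it finds a key with matches, and next() replays the match list through a single reusable JoinedRow, so no per-row closures or intermediate collections are allocated. The same pattern in isolation, with illustrative (K, String) pairs in place of catalyst rows and a hash table built as in the commit:

import scala.collection.mutable.ArrayBuffer

object StreamingMatchSketch {
  // Lazily joins a streamed iterator of (key, value) against a prebuilt table.
  def joinMatches[K](
      stream: Iterator[(K, String)],
      table: java.util.HashMap[K, ArrayBuffer[String]]): Iterator[(String, String)] =
    new Iterator[(String, String)] {
      private var current: (K, String) = _
      private var matchList: ArrayBuffer[String] = _
      private var pos = -1

      def hasNext: Boolean =
        (pos != -1 && pos < matchList.size) || (stream.hasNext && fetchNext())

      def next(): (String, String) = {
        val out = (current._2, matchList(pos)) // streamed value first
        pos += 1
        out
      }

      // Advance the streamed side until some key has at least one match.
      private def fetchNext(): Boolean = {
        matchList = null
        pos = -1
        while (matchList == null && stream.hasNext) {
          current = stream.next()
          matchList = table.get(current._1)
        }
        if (matchList == null) false
        else { pos = 0; true }
      }
    }
}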

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
-      SparkEquiInnerJoin,
+      HashJoin,
       BasicOperators,
       CartesianProduct,
       BroadcastNestedLoopJoin
