1515 * limitations under the License.
1616 */
1717
18- package org .apache .spark .sql .execution
18+ package org .apache .spark .sql
19+ package execution
1920
20- import scala .collection .mutable
21+ import scala .collection .mutable .{ ArrayBuffer , BitSet }
2122
2223import org .apache .spark .rdd .RDD
2324import org .apache .spark .SparkContext
2425
25- import org . apache . spark . sql . catalyst .errors ._
26- import org . apache . spark . sql . catalyst .expressions ._
27- import org . apache . spark . sql . catalyst .plans ._
28- import org . apache . spark . sql . catalyst .plans .physical .{ClusteredDistribution , Partitioning }
26+ import catalyst .errors ._
27+ import catalyst .expressions ._
28+ import catalyst .plans ._
29+ import catalyst .plans .physical .{ClusteredDistribution , Partitioning }
2930
30- import org .apache .spark .rdd .PartitionLocalRDDFunctions ._
31+ sealed abstract class BuildSide
32+ case object BuildLeft extends BuildSide
33+ case object BuildRight extends BuildSide
3134
32- case class SparkEquiInnerJoin (
35+ object InterpretCondition {
36+ def apply (expression : Expression ): (Row => Boolean ) = {
37+ (r : Row ) => expression.apply(r).asInstanceOf [Boolean ]
38+ }
39+ }
40+
41+ case class HashJoin (
3342 leftKeys : Seq [Expression ],
3443 rightKeys : Seq [Expression ],
44+ buildSide : BuildSide ,
3545 left : SparkPlan ,
3646 right : SparkPlan ) extends BinaryNode {
3747
@@ -40,33 +50,85 @@ case class SparkEquiInnerJoin(
4050 override def requiredChildDistribution =
4151 ClusteredDistribution (leftKeys) :: ClusteredDistribution (rightKeys) :: Nil
4252
53+ val (buildPlan, streamedPlan) = buildSide match {
54+ case BuildLeft => (left, right)
55+ case BuildRight => (right, left)
56+ }
57+
58+ val (buildKeys, streamedKeys) = buildSide match {
59+ case BuildLeft => (leftKeys, rightKeys)
60+ case BuildRight => (rightKeys, leftKeys)
61+ }
62+
4363 def output = left.output ++ right.output
4464
45- def execute () = attachTree(this , " execute" ) {
46- val leftWithKeys = left.execute().mapPartitions { iter =>
47- val generateLeftKeys = new Projection (leftKeys, left.output)
48- iter.map(row => (generateLeftKeys(row), row.copy()))
49- }
65+ @ transient lazy val buildSideKeyGenerator = new Projection (buildKeys, buildPlan.output)
66+ @ transient lazy val streamSideKeyGenerator =
67+ () => new MutableProjection (streamedKeys, streamedPlan.output)
5068
51- val rightWithKeys = right.execute().mapPartitions { iter =>
52- val generateRightKeys = new Projection (rightKeys, right.output)
53- iter.map(row => (generateRightKeys(row), row.copy()))
54- }
69+ def execute () = {
5570
56- // Do the join.
57- val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
58- // Drop join keys and merge input tuples.
59- joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
60- }
71+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
72+ val hashTable = new java.util.HashMap [Row , ArrayBuffer [Row ]]()
73+ var currentRow : Row = null
74+
75+ // Create a mapping of buildKeys -> rows
76+ while (buildIter.hasNext) {
77+ currentRow = buildIter.next()
78+ val rowKey = buildSideKeyGenerator(currentRow)
79+ if (! rowKey.anyNull) {
80+ val existingMatchList = hashTable.get(rowKey)
81+ val matchList = if (existingMatchList == null ) {
82+ val newMatchList = new ArrayBuffer [Row ]()
83+ hashTable.put(rowKey, newMatchList)
84+ newMatchList
85+ } else {
86+ existingMatchList
87+ }
88+ matchList += currentRow.copy()
89+ }
90+ }
6191
62- /**
63- * Filters any rows where the any of the join keys is null, ensuring three-valued
64- * logic for the equi-join conditions.
65- */
66- protected def filterNulls (rdd : RDD [(Row , Row )]) =
67- rdd.filter {
68- case (key : Seq [_], _) => ! key.exists(_ == null )
92+ new Iterator [Row ] {
93+ private [this ] var currentRow : Row = _
94+ private [this ] var currentMatches : ArrayBuffer [Row ] = _
95+ private [this ] var currentPosition : Int = - 1
96+
97+ // Mutable per row objects.
98+ private [this ] val joinRow = new JoinedRow
99+
100+ @ transient private val joinKeys = streamSideKeyGenerator()
101+
102+ def hasNext : Boolean =
103+ (currentPosition != - 1 && currentPosition < currentMatches.size) ||
104+ (streamIter.hasNext && fetchNext())
105+
106+ def next () = {
107+ val ret = joinRow(currentRow, currentMatches(currentPosition))
108+ currentPosition += 1
109+ ret
110+ }
111+
112+ private def fetchNext (): Boolean = {
113+ currentMatches = null
114+ currentPosition = - 1
115+
116+ while (currentMatches == null && streamIter.hasNext) {
117+ currentRow = streamIter.next()
118+ if (! joinKeys(currentRow).anyNull)
119+ currentMatches = hashTable.get(joinKeys.currentValue)
120+ }
121+
122+ if (currentMatches == null ) {
123+ false
124+ } else {
125+ currentPosition = 0
126+ true
127+ }
128+ }
129+ }
69130 }
131+ }
70132}
71133
72134case class CartesianProduct (left : SparkPlan , right : SparkPlan ) extends BinaryNode {
@@ -95,17 +157,18 @@ case class BroadcastNestedLoopJoin(
95157 def right = broadcast
96158
97159 @ transient lazy val boundCondition =
98- condition
99- .map(c => BindReferences .bindReference(c, left.output ++ right.output))
100- .getOrElse(Literal (true ))
160+ InterpretCondition (
161+ condition
162+ .map(c => BindReferences .bindReference(c, left.output ++ right.output))
163+ .getOrElse(Literal (true )))
101164
102165
103166 def execute () = {
104167 val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
105168
106169 val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
107- val matchedRows = new mutable. ArrayBuffer [Row ]
108- val includedBroadcastTuples = new mutable.BitSet (broadcastedRelation.value.size)
170+ val matchedRows = new ArrayBuffer [Row ]
171+ val includedBroadcastTuples = new scala.collection. mutable.BitSet (broadcastedRelation.value.size)
109172 val joinedRow = new JoinedRow
110173
111174 streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
115178 while (i < broadcastedRelation.value.size) {
116179 // TODO: One bitset per partition instead of per row.
117180 val broadcastedRow = broadcastedRelation.value(i)
118- if (boundCondition(joinedRow(streamedRow, broadcastedRow)). asInstanceOf [ Boolean ] ) {
181+ if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
119182 matchedRows += buildRow(streamedRow ++ broadcastedRow)
120183 matched = true
121184 includedBroadcastTuples += i
0 commit comments