Skip to content

Commit 4bf3de7

Browse files
mateiz authored and marmbrus committed
[SPARK-3085] [SQL] Use compact data structures in SQL joins
This reuses the CompactBuffer from Spark Core to save memory and pointer dereferences. I also tried AppendOnlyMap instead of java.util.HashMap, but unfortunately that slows things down because it seems to do more equals() calls, and the equals() on GenericRow, and especially JoinedRow, is pretty expensive.

Author: Matei Zaharia <matei@databricks.com>

Closes #1993 from mateiz/spark-3085 and squashes the following commits:

188221e [Matei Zaharia] Remove unneeded import
5f903ee [Matei Zaharia] [SPARK-3085] [SQL] Use compact data structures in SQL joins
1 parent 6a13dca commit 4bf3de7

File tree

1 file changed

+33
-34
lines changed
  • sql/core/src/main/scala/org/apache/spark/sql/execution

1 file changed

+33
-34
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,15 @@ package org.apache.spark.sql.execution
1919

2020
import java.util.{HashMap => JavaHashMap}
2121

22-
import scala.collection.mutable.{ArrayBuffer, BitSet}
2322
import scala.concurrent.ExecutionContext.Implicits.global
2423
import scala.concurrent._
2524
import scala.concurrent.duration._
2625

2726
import org.apache.spark.annotation.DeveloperApi
28-
import org.apache.spark.sql.SQLContext
2927
import org.apache.spark.sql.catalyst.expressions._
3028
import org.apache.spark.sql.catalyst.plans._
3129
import org.apache.spark.sql.catalyst.plans.physical._
30+
import org.apache.spark.util.collection.CompactBuffer
3231

3332
@DeveloperApi
3433
sealed abstract class BuildSide
@@ -67,7 +66,7 @@ trait HashJoin {
6766
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
6867
// TODO: Use Spark's HashMap implementation.
6968

70-
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
69+
val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
7170
var currentRow: Row = null
7271

7372
// Create a mapping of buildKeys -> rows
@@ -77,7 +76,7 @@ trait HashJoin {
7776
if (!rowKey.anyNull) {
7877
val existingMatchList = hashTable.get(rowKey)
7978
val matchList = if (existingMatchList == null) {
80-
val newMatchList = new ArrayBuffer[Row]()
79+
val newMatchList = new CompactBuffer[Row]()
8180
hashTable.put(rowKey, newMatchList)
8281
newMatchList
8382
} else {
@@ -89,7 +88,7 @@ trait HashJoin {
8988

9089
new Iterator[Row] {
9190
private[this] var currentStreamedRow: Row = _
92-
private[this] var currentHashMatches: ArrayBuffer[Row] = _
91+
private[this] var currentHashMatches: CompactBuffer[Row] = _
9392
private[this] var currentMatchPosition: Int = -1
9493

9594
// Mutable per row objects.
@@ -140,7 +139,7 @@ trait HashJoin {
140139

141140
/**
142141
* :: DeveloperApi ::
143-
* Performs a hash based outer join for two child relations by shuffling the data using
142+
* Performs a hash based outer join for two child relations by shuffling the data using
144143
* the join keys. This operator requires loading the associated partition in both side into memory.
145144
*/
146145
@DeveloperApi
@@ -179,26 +178,26 @@ case class HashOuterJoin(
179178
@transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
180179

181180
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
182-
// iterator for performance purpose.
181+
// iterator for performance purpose.
183182

184183
private[this] def leftOuterIterator(
185184
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
186185
val joinedRow = new JoinedRow()
187186
val rightNullRow = new GenericRow(right.output.length)
188-
val boundCondition =
187+
val boundCondition =
189188
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
190189

191-
leftIter.iterator.flatMap { l =>
190+
leftIter.iterator.flatMap { l =>
192191
joinedRow.withLeft(l)
193192
var matched = false
194-
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
193+
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
195194
matched = true
196195
joinedRow.copy
197196
} else {
198197
Nil
199198
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
200199
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
201-
// as we don't know whether we need to append it until finish iterating all of the
200+
// as we don't know whether we need to append it until finish iterating all of the
202201
// records in right side.
203202
// If we didn't get any proper row, then append a single row with empty right
204203
joinedRow.withRight(rightNullRow).copy
@@ -210,20 +209,20 @@ case class HashOuterJoin(
210209
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
211210
val joinedRow = new JoinedRow()
212211
val leftNullRow = new GenericRow(left.output.length)
213-
val boundCondition =
212+
val boundCondition =
214213
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
215214

216-
rightIter.iterator.flatMap { r =>
215+
rightIter.iterator.flatMap { r =>
217216
joinedRow.withRight(r)
218217
var matched = false
219-
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
218+
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
220219
matched = true
221220
joinedRow.copy
222221
} else {
223222
Nil
224223
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
225224
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
226-
// as we don't know whether we need to append it until finish iterating all of the
225+
// as we don't know whether we need to append it until finish iterating all of the
227226
// records in left side.
228227
// If we didn't get any proper row, then append a single row with empty left.
229228
joinedRow.withLeft(leftNullRow).copy
@@ -236,7 +235,7 @@ case class HashOuterJoin(
236235
val joinedRow = new JoinedRow()
237236
val leftNullRow = new GenericRow(left.output.length)
238237
val rightNullRow = new GenericRow(right.output.length)
239-
val boundCondition =
238+
val boundCondition =
240239
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
241240

242241
if (!key.anyNull) {
@@ -246,8 +245,8 @@ case class HashOuterJoin(
246245
leftIter.iterator.flatMap[Row] { l =>
247246
joinedRow.withLeft(l)
248247
var matched = false
249-
rightIter.zipWithIndex.collect {
250-
// 1. For those matched (satisfy the join condition) records with both sides filled,
248+
rightIter.zipWithIndex.collect {
249+
// 1. For those matched (satisfy the join condition) records with both sides filled,
251250
// append them directly
252251

253252
case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
@@ -260,16 +259,16 @@ case class HashOuterJoin(
260259
// 2. For those unmatched records in left, append additional records with empty right.
261260

262261
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
263-
// as we don't know whether we need to append it until finish iterating all
262+
// as we don't know whether we need to append it until finish iterating all
264263
// of the records in right side.
265264
// If we didn't get any proper row, then append a single row with empty right.
266265
joinedRow.withRight(rightNullRow).copy
267266
})
268267
} ++ rightIter.zipWithIndex.collect {
269268
// 3. For those unmatched records in right, append additional records with empty left.
270269

271-
// Re-visiting the records in right, and append additional row with empty left, if its not
272-
// in the matched set.
270+
// Re-visiting the records in right, and append additional row with empty left, if its not
271+
// in the matched set.
273272
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
274273
joinedRow(leftNullRow, r).copy
275274
}
@@ -284,15 +283,15 @@ case class HashOuterJoin(
284283
}
285284

286285
private[this] def buildHashTable(
287-
iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = {
288-
val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]()
286+
iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
287+
val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
289288
while (iter.hasNext) {
290289
val currentRow = iter.next()
291290
val rowKey = keyGenerator(currentRow)
292291

293292
var existingMatchList = hashTable.get(rowKey)
294293
if (existingMatchList == null) {
295-
existingMatchList = new ArrayBuffer[Row]()
294+
existingMatchList = new CompactBuffer[Row]()
296295
hashTable.put(rowKey, existingMatchList)
297296
}
298297

@@ -311,20 +310,20 @@ case class HashOuterJoin(
311310
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
312311

313312
import scala.collection.JavaConversions._
314-
val boundCondition =
313+
val boundCondition =
315314
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
316315
joinType match {
317316
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
318-
leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
317+
leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
319318
rightHashTable.getOrElse(key, EMPTY_LIST))
320319
}
321320
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
322-
rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
321+
rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
323322
rightHashTable.getOrElse(key, EMPTY_LIST))
324323
}
325324
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
326-
fullOuterIterator(key,
327-
leftHashTable.getOrElse(key, EMPTY_LIST),
325+
fullOuterIterator(key,
326+
leftHashTable.getOrElse(key, EMPTY_LIST),
328327
rightHashTable.getOrElse(key, EMPTY_LIST))
329328
}
330329
case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
@@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin(
550549

551550
/** All rows that either match both-way, or rows from streamed joined with nulls. */
552551
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
553-
val matchedRows = new ArrayBuffer[Row]
552+
val matchedRows = new CompactBuffer[Row]
554553
// TODO: Use Spark's BitSet.
555554
val includedBroadcastTuples =
556555
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
@@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin(
602601
val rightNulls = new GenericMutableRow(right.output.size)
603602
/** Rows from broadcasted joined with nulls. */
604603
val broadcastRowsWithNulls: Seq[Row] = {
605-
val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer()
604+
val buf: CompactBuffer[Row] = new CompactBuffer()
606605
var i = 0
607606
val rel = broadcastedRelation.value
608607
while (i < rel.length) {
609608
if (!allIncludedBroadcastTuples.contains(i)) {
610609
(joinType, buildSide) match {
611-
case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i))
612-
case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls)
610+
case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
611+
case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
613612
case _ =>
614613
}
615614
}
616615
i += 1
617616
}
618-
arrBuf.toSeq
617+
buf.toSeq
619618
}
620619

621620
// TODO: Breaks lineage.

0 commit comments

Comments
 (0)