Skip to content

Commit 22a2767

Browse files
committed
spliting append func into two parts:grow/append;doubling the size when growing;sys.error instead of UnsupportedOperationException
1 parent d9d8e62 commit 22a2767

File tree

1 file changed

+33
-18
lines changed

1 file changed

+33
-18
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
3030
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
3131
import org.apache.spark.sql.types.LongType
3232
import org.apache.spark.unsafe.Platform
33+
import org.apache.spark.unsafe.array.ByteArrayMethods
3334
import org.apache.spark.unsafe.map.BytesToBytesMap
3435
import org.apache.spark.util.{KnownSizeEstimation, Utils}
3536

@@ -362,6 +363,8 @@ private[joins] object UnsafeHashedRelation {
362363
private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, capacity: Int)
363364
extends MemoryConsumer(mm) with Externalizable with KryoSerializable {
364365

366+
private val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
367+
365368
// Whether the keys are stored in dense mode or not.
366369
private var isDense = false
367370

@@ -557,7 +560,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
557560
def append(key: Long, row: UnsafeRow): Unit = {
558561
val sizeInBytes = row.getSizeInBytes
559562
if (sizeInBytes >= (1 << SIZE_BITS)) {
560-
sys.error("Does not support row that is larger than 256M")
563+
throw new UnsupportedOperationException("Does not support row that is larger than 256M")
561564
}
562565

563566
if (key < minKey) {
@@ -567,22 +570,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
567570
maxKey = key
568571
}
569572

570-
// There is 8 bytes for the pointer to next value
571-
val needSize = cursor + 8 + row.getSizeInBytes
572-
val nowSize = page.length * 8L + Platform.LONG_ARRAY_OFFSET
573-
if (needSize > nowSize) {
574-
val used = page.length
575-
if (used >= (1 << 30)) {
576-
sys.error("Can not build a HashedRelation that is larger than 8G")
577-
}
578-
val multiples = math.max(math.ceil(needSize.toDouble / (used * 8L)).toInt, 2)
579-
ensureAcquireMemory(used * 8L * multiples)
580-
val newPage = new Array[Long](used * multiples)
581-
Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
582-
cursor - Platform.LONG_ARRAY_OFFSET)
583-
page = newPage
584-
freeMemory(used * 8L)
585-
}
573+
grow(row.getSizeInBytes)
586574

587575
// copy the bytes of UnsafeRow
588576
val offset = cursor
@@ -618,7 +606,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
618606
growArray()
619607
} else if (numKeys > array.length / 2 * 0.75) {
620608
// The fill ratio should be less than 0.75
621-
sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
609+
throw new UnsupportedOperationException(
610+
"Cannot build HashedRelation with more than 1/3 billions unique keys")
622611
}
623612
}
624613
} else {
@@ -629,6 +618,32 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
629618
}
630619
}
631620

621+
private def grow(neededSize: Int): Unit = {
622+
// There is 8 bytes for the pointer to next value
623+
val totalNeededSize = cursor + 8 + neededSize
624+
val nowSize = page.length * 8L + Platform.LONG_ARRAY_OFFSET
625+
if (totalNeededSize > nowSize) {
626+
val used = page.length
627+
if (used >= (1 << 30)) {
628+
throw new UnsupportedOperationException(
629+
"Can not build a HashedRelation that is larger than 8G")
630+
}
631+
val multiples = math.floor(totalNeededSize.toDouble / nowSize).toInt * 2
632+
val newLength = used * multiples
633+
if (newLength > ARRAY_MAX) {
634+
throw new UnsupportedOperationException(
635+
"Cannot grow internal buffer by size " + newLength +
636+
" because the size after growing " + "exceeds size limitation " + ARRAY_MAX)
637+
}
638+
ensureAcquireMemory(newLength * 8L)
639+
val newPage = new Array[Long](newLength)
640+
Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
641+
cursor - Platform.LONG_ARRAY_OFFSET)
642+
page = newPage
643+
freeMemory(used * 8L)
644+
}
645+
}
646+
632647
private def growArray(): Unit = {
633648
var old_array = array
634649
val n = array.length

0 commit comments

Comments
 (0)