@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
30
30
import org .apache .spark .sql .catalyst .plans .physical .BroadcastMode
31
31
import org .apache .spark .sql .types .LongType
32
32
import org .apache .spark .unsafe .Platform
33
+ import org .apache .spark .unsafe .array .ByteArrayMethods
33
34
import org .apache .spark .unsafe .map .BytesToBytesMap
34
35
import org .apache .spark .util .{KnownSizeEstimation , Utils }
35
36
@@ -362,6 +363,8 @@ private[joins] object UnsafeHashedRelation {
362
363
private [execution] final class LongToUnsafeRowMap (val mm : TaskMemoryManager , capacity : Int )
363
364
extends MemoryConsumer (mm) with Externalizable with KryoSerializable {
364
365
366
+ private val ARRAY_MAX = ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH
367
+
365
368
// Whether the keys are stored in dense mode or not.
366
369
private var isDense = false
367
370
@@ -557,7 +560,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
557
560
def append (key : Long , row : UnsafeRow ): Unit = {
558
561
val sizeInBytes = row.getSizeInBytes
559
562
if (sizeInBytes >= (1 << SIZE_BITS )) {
560
- sys.error (" Does not support row that is larger than 256M" )
563
+ throw new UnsupportedOperationException (" Does not support row that is larger than 256M" )
561
564
}
562
565
563
566
if (key < minKey) {
@@ -567,22 +570,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
567
570
maxKey = key
568
571
}
569
572
570
- // There is 8 bytes for the pointer to next value
571
- val needSize = cursor + 8 + row.getSizeInBytes
572
- val nowSize = page.length * 8L + Platform .LONG_ARRAY_OFFSET
573
- if (needSize > nowSize) {
574
- val used = page.length
575
- if (used >= (1 << 30 )) {
576
- sys.error(" Can not build a HashedRelation that is larger than 8G" )
577
- }
578
- val multiples = math.max(math.ceil(needSize.toDouble / (used * 8L )).toInt, 2 )
579
- ensureAcquireMemory(used * 8L * multiples)
580
- val newPage = new Array [Long ](used * multiples)
581
- Platform .copyMemory(page, Platform .LONG_ARRAY_OFFSET , newPage, Platform .LONG_ARRAY_OFFSET ,
582
- cursor - Platform .LONG_ARRAY_OFFSET )
583
- page = newPage
584
- freeMemory(used * 8L )
585
- }
573
+ grow(row.getSizeInBytes)
586
574
587
575
// copy the bytes of UnsafeRow
588
576
val offset = cursor
@@ -618,7 +606,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
618
606
growArray()
619
607
} else if (numKeys > array.length / 2 * 0.75 ) {
620
608
// The fill ratio should be less than 0.75
621
- sys.error(" Cannot build HashedRelation with more than 1/3 billions unique keys" )
609
+ throw new UnsupportedOperationException (
610
+ " Cannot build HashedRelation with more than 1/3 billions unique keys" )
622
611
}
623
612
}
624
613
} else {
@@ -629,6 +618,32 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
629
618
}
630
619
}
631
620
621
/**
 * Makes sure `page` can hold one more value of `neededSize` bytes together with the
 * 8-byte pointer to the next value, replacing `page` with a larger array if it cannot.
 *
 * When growth is required the existing contents (everything before `cursor`) are copied
 * into the new array, memory for the new array is acquired before allocation and the
 * memory accounted for the old array is released afterwards.
 *
 * @param neededSize size in bytes of the value about to be written at `cursor`
 * @throws UnsupportedOperationException if `page` already holds 2^30 or more longs, or
 *                                       if the grown page would exceed `ARRAY_MAX` elements
 */
private def grow(neededSize: Int): Unit = {
  // There are 8 bytes for the pointer to next value
  val totalNeededSize = cursor + 8 + neededSize
  val nowSize = page.length * 8L + Platform.LONG_ARRAY_OFFSET
  if (totalNeededSize > nowSize) {
    val used = page.length
    if (used >= (1 << 30)) {
      throw new UnsupportedOperationException(
        " Can not build a HashedRelation that is larger than 8G")
    }
    // Same growth policy as before (twice the whole-number ratio of needed to current
    // size), but evaluated in Long arithmetic: `used * multiples` as an Int product can
    // wrap to a negative value and slip past the ARRAY_MAX guard below.
    val multiples = (totalNeededSize / nowSize) * 2
    // Both totalNeededSize and nowSize include the array header offset, whereas the
    // capacity gained is only `page.length * 8` bytes. Compute the exact number of words
    // the payload needs (rounded up) so the new page is never a few bytes too small.
    val neededWords = (totalNeededSize - Platform.LONG_ARRAY_OFFSET + 7L) / 8
    val newLength = math.max(used * multiples, neededWords)
    if (newLength > ARRAY_MAX) {
      throw new UnsupportedOperationException(
        " Cannot grow internal buffer by size " + newLength +
          " because the size after growing " + " exceeds size limitation " + ARRAY_MAX)
    }
    ensureAcquireMemory(newLength * 8L)
    // Safe narrowing: newLength was just checked against ARRAY_MAX.
    val newPage = new Array[Long](newLength.toInt)
    Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
      cursor - Platform.LONG_ARRAY_OFFSET)
    page = newPage
    freeMemory(used * 8L)
  }
}
646
+
632
647
private def growArray (): Unit = {
633
648
var old_array = array
634
649
val n = array.length
0 commit comments