Commit c227732

cxzl25 authored and rdblue committed
[SPARK-24257][SQL] LongToUnsafeRowMap may calculate the new size incorrectly
LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2`, but the new record may be larger than `oldSize * 2`. Querying the map can then return a malformed UnsafeRow whose actual data is smaller than its declared size, i.e. the data is corrupted.

Author: sychen <sychen@ctrip.com>

Closes apache#21311 from cxzl25/fix_LongToUnsafeRowMap_page_size.
1 parent: d664651

File tree: 2 files changed (+48, -16 lines)
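
To make the failure mode concrete before the diffs, here is a minimal, self-contained Scala sketch of the two sizing policies. All names here (PageSizingSketch, oldNumWords, newNumWords) are illustrative, not Spark's API; the arithmetic mirrors the patch below, working in 8-byte words.

object PageSizingSketch {
  // Old policy: double the page unconditionally, ignoring the incoming row.
  def oldNumWords(pageNumWords: Int): Long = pageNumWords * 2L

  // Fixed policy: bytes already used, plus the 8-byte pointer to the next
  // value, plus the incoming row, rounded up to whole 8-byte words; grow to
  // at least that, preferring doubling, capped at 1 << 30 words (8G).
  def newNumWords(usedBytes: Long, rowSizeBytes: Int, pageNumWords: Int): Long = {
    val needed = (usedBytes + 8 + rowSizeBytes + 7) / 8
    math.max(needed, math.min(pageNumWords * 2L, 1L << 30))
  }

  def main(args: Array[String]): Unit = {
    val pageNumWords = 1 << 17 // 128K words = 1M bytes
    val rowSizeBytes = 3 << 20 // a hypothetical 3M-byte row
    // Old policy yields 2M bytes, still smaller than the row, so the copy
    // overruns the page and the stored UnsafeRow ends up malformed.
    println(s"old: ${oldNumWords(pageNumWords) * 8} bytes")
    // Fixed policy yields ~3M bytes, enough to hold the row.
    println(s"new: ${newNumWords(0, rowSizeBytes, pageNumWords) * 8} bytes")
  }
}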

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
(23 additions & 15 deletions)

@@ -533,7 +533,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   def append(key: Long, row: UnsafeRow): Unit = {
     val sizeInBytes = row.getSizeInBytes
     if (sizeInBytes >= (1 << SIZE_BITS)) {
-      sys.error("Does not support row that is larger than 256M")
+      throw new UnsupportedOperationException("Does not support row that is larger than 256M")
     }

     if (key < minKey) {
@@ -543,19 +543,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       maxKey = key
     }

-    // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
-      val used = page.length
-      if (used >= (1 << 30)) {
-        sys.error("Can not build a HashedRelation that is larger than 8G")
-      }
-      ensureAcquireMemory(used * 8L * 2)
-      val newPage = new Array[Long](used * 2)
-      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
-        cursor - Platform.LONG_ARRAY_OFFSET)
-      page = newPage
-      freeMemory(used * 8L)
-    }
+    grow(row.getSizeInBytes)

     // copy the bytes of UnsafeRow
     val offset = cursor
@@ -588,7 +576,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
           growArray()
         } else if (numKeys > array.length / 2 * 0.75) {
           // The fill ratio should be less than 0.75
-          sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+          throw new UnsupportedOperationException(
+            "Cannot build HashedRelation with more than 1/3 billions unique keys")
         }
       }
     } else {
@@ -599,6 +588,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     }
   }

+  private def grow(inputRowSize: Int): Unit = {
+    // There is 8 bytes for the pointer to next value
+    val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+    if (neededNumWords > page.length) {
+      if (neededNumWords > (1 << 30)) {
+        throw new UnsupportedOperationException(
+          "Can not build a HashedRelation that is larger than 8G")
+      }
+      val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+      ensureAcquireMemory(newNumWords * 8L)
+      val newPage = new Array[Long](newNumWords.toInt)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
+      val used = page.length
+      page = newPage
+      freeMemory(used * 8L)
+    }
+  }
+
   private def growArray(): Unit = {
     var old_array = array
     val n = array.length
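
A note on the arithmetic in the new grow helper above: `(cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8` converts the bytes already written, plus the 8-byte pointer to the next value, plus the incoming row into a count of 8-byte words, rounding up via the `+ 7`. `newNumWords` then keeps the old doubling behavior whenever doubling suffices, but otherwise grows directly to the needed size, still subject to the `1 << 30`-word (8G) cap enforced by the exception.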

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
(25 additions & 1 deletion)

@@ -27,7 +27,7 @@ import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer
@@ -253,6 +253,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     map.free()
   }

+  test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+
+    val key = 0L
+    // the page array is initialized with length 1 << 17 (1M bytes),
+    // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug
+    val bigStr = UTF8String.fromString("x" * (1 << 19))
+
+    map.append(key, unsafeProj(InternalRow(bigStr)))
+    map.optimize()
+
+    val resultRow = new UnsafeRow(1)
+    assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
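
To verify the fix locally, assuming a Spark source checkout and Spark's usual sbt workflow (the exact command form is an assumption, not part of this commit), the new test can be run with something like `build/sbt "sql/testOnly *HashedRelationSuite -- -z SPARK-24257"`; it should fail before the patch and pass after it.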
