
Commit 8d40a79

kiszk authored and hvanhovell committed
[SPARK-23893][CORE][SQL] Avoid possible integer overflow in multiplication
## What changes were proposed in this pull request?

This PR avoids a possible integer overflow in operations of the form `long = (long)(int * int)`. Multiplying two large positive `int` values can set the most significant bit of the 32-bit product, so the widened `long` ends up negative where a positive value was expected (e.g. `0111_0000_0000_0000 * 0000_0000_0000_0010`). This PR casts one operand to `long` before the multiplication so the arithmetic is performed in 64 bits.

## How was this patch tested?

Existing UTs

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #21002 from kiszk/SPARK-23893.
1 parent 710a68c commit 8d40a79
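
To make the failure mode concrete, here is a minimal, self-contained Scala sketch (not part of the patch; the values are illustrative) showing how the 32-bit product wraps before it is widened, and how a `Long` literal avoids it:

```scala
object OverflowDemo {
  def main(args: Array[String]): Unit = {
    val initialSize = 1500000000          // a large positive Int

    // The multiplication runs in 32-bit Int arithmetic and wraps around;
    // widening the already-wrapped result to Long cannot repair it.
    val wrong: Long = initialSize * 2     // -1294967296

    // A Long literal (or .toLong on one operand) forces 64-bit arithmetic
    // before the product is formed.
    val right: Long = initialSize * 2L    // 3000000000

    println(wrong)
    println(right)
  }
}
```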

11 files changed: +15 −14 lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ public UnsafeInMemorySorter(
       int initialSize,
       boolean canUseRadixSort) {
     this(consumer, memoryManager, recordComparator, prefixComparator,
-      consumer.allocateArray(initialSize * 2), canUseRadixSort);
+      consumer.allocateArray(initialSize * 2L), canUseRadixSort);
   }
 
   public UnsafeInMemorySorter(

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int
 
   @Override
   public LongArray allocate(int length) {
-    assert (length * 2 <= buffer.size()) :
+    assert (length * 2L <= buffer.size()) :
       "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2);
     return buffer;
   }

core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       // Otherwise, interpolate the number of partitions we need to try, but overestimate it
       // by 50%. We also cap the estimation in the end.
       if (results.size == 0) {
-        numPartsToTry = partsScanned * 4
+        numPartsToTry = partsScanned * 4L
       } else {
         // the left side of max is >=1 whenever partsScanned >= 2
         numPartsToTry = Math.max(1,

core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
       // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt
       // ID that's < 2 * numPartitions belongs to the first attempt of this stage.
       val taskContext = TaskContext.get()
-      val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2
+      val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L
       if (isFirstStageAttempt) {
         throw new FetchFailedException(
           SparkEnv.get.blockManager.blockManagerId,

core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -382,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
       val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false)
       writeFile(log, true, None,
         SparkListenerApplicationStart(
-          "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")),
-        SparkListenerApplicationEnd(5001 * i)
+          "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")),
+        SparkListenerApplicationEnd(5001L * i)
       )
       log
     }

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite {
   test("SparkListenerJobStart backward compatibility") {
     // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property.
     val stageIds = Seq[Int](1, 2, 3, 4)
-    val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500))
+    val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L))
     val dummyStageInfos =
       stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown"))
     val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties)
@@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite {
     // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property.
     // Also, SparkListenerJobEnd did not have a "Completion Time" property.
     val stageIds = Seq[Int](1, 2, 3, 4)
-    val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50))
+    val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L))
     val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties)
     val oldStartEvent = JsonProtocol.jobStartToJson(jobStart)
       .removeField({ _._1 == "Submission Time"})

sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ object HashBenchmark {
       safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
     ).toArray
 
-    val benchmark = new Benchmark("Hash For " + name, iters * numRows)
+    val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong)
     benchmark.addCase("interpreted version") { _: Int =>
       var sum = 0
       for (_ <- 0L until iters) {

sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,8 @@ object HashByteArrayBenchmark {
       bytes
     }
 
-    val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays)
+    val benchmark =
+      new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong)
     benchmark.addCase("Murmur3_x86_32") { _: Int =>
       var sum = 0L
       for (_ <- 0L until iters) {

sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark {
     val iters = 1024 * 16
     val numRows = 1024 * 16
 
-    val benchmark = new Benchmark("unsafe projection", iters * numRows)
+    val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong)
 
 
     val schema1 = new StructType().add("l", LongType, false)
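
A related subtlety, shown with a small Scala sketch (illustrative values, not from the patch): widening the finished product is already too late, because the 32-bit multiplication wraps before the cast runs. The operand itself must be widened, which is why these fixes use `numRows.toLong` rather than `(iters * numRows).toLong`:

```scala
val a = 65536
val b = 65536

// Too late: 65536 * 65536 wraps to 0 in Int before .toLong is applied.
val stillWrong: Long = (a * b).toLong   // 0

// Promoting one operand first makes the whole product 64-bit.
val correct: Long = a.toLong * b        // 4294967296
```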

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes {
       count: Int,
       tpe: NativeColumnType[T],
       input: ByteBuffer): Unit = {
-    val benchmark = new Benchmark(name, iters * count)
+    val benchmark = new Benchmark(name, iters * count.toLong)
 
     schemes.filter(_.supports(tpe)).foreach { scheme =>
       val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input)
@@ -101,7 +101,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes {
       count: Int,
       tpe: NativeColumnType[T],
       input: ByteBuffer): Unit = {
-    val benchmark = new Benchmark(name, iters * count)
+    val benchmark = new Benchmark(name, iters * count.toLong)
 
     schemes.filter(_.supports(tpe)).foreach { scheme =>
       val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input)

sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala

Lines changed: 1 addition & 1 deletion
@@ -295,7 +295,7 @@ object ColumnarBatchBenchmark {
 
   def booleanAccess(iters: Int): Unit = {
     val count = 8 * 1024
-    val benchmark = new Benchmark("Boolean Read/Write", iters * count)
+    val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong)
     benchmark.addCase("Bitset") { i: Int => {
       val b = new BitSet(count)
       var sum = 0L
