Skip to content

Commit ae6820a

Browse files
committed
replace int with long in RadixSort.java
1 parent 472c0c3 commit ae6820a

File tree

5 files changed

+51
-51
lines changed

5 files changed

+51
-51
lines changed

core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,9 @@ public ShuffleSorterIterator getSortedIterator() {
176176
int offset = 0;
177177
if (useRadixSort) {
178178
offset = RadixSort.sort(
179-
array, pos,
180-
PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
181-
PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
179+
array, (long)pos,
180+
(long)PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
181+
(long)PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
182182
} else {
183183
MemoryBlock unused = new MemoryBlock(
184184
array.getBaseObject(),

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,28 +40,28 @@ public class RadixSort {
4040
* of always copying the data back to position zero for efficiency.
4141
*/
4242
public static int sort(
43-
LongArray array, int numRecords, int startByteIndex, int endByteIndex,
43+
LongArray array, long numRecords, long startByteIndex, long endByteIndex,
4444
boolean desc, boolean signed) {
4545
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
4646
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
4747
assert endByteIndex > startByteIndex;
4848
assert numRecords * 2 <= array.size();
49-
int inIndex = 0;
50-
int outIndex = numRecords;
49+
long inIndex = 0;
50+
long outIndex = numRecords;
5151
if (numRecords > 0) {
5252
long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
53-
for (int i = startByteIndex; i <= endByteIndex; i++) {
54-
if (counts[i] != null) {
53+
for (long i = startByteIndex; i <= endByteIndex; i++) {
54+
if (counts[(int)i] != null) {
5555
sortAtByte(
56-
array, numRecords, counts[i], i, inIndex, outIndex,
56+
array, numRecords, counts[(int)i], i, inIndex, outIndex,
5757
desc, signed && i == endByteIndex);
58-
int tmp = inIndex;
58+
long tmp = inIndex;
5959
inIndex = outIndex;
6060
outIndex = tmp;
6161
}
6262
}
6363
}
64-
return inIndex;
64+
return (int)inIndex;
6565
}
6666

6767
/**
@@ -78,7 +78,7 @@ public static int sort(
7878
* @param signed whether this is a signed (two's complement) sort (only applies to last byte).
7979
*/
8080
private static void sortAtByte(
81-
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
81+
LongArray array, long numRecords, long[] counts, long byteIdx, long inIndex, long outIndex,
8282
boolean desc, boolean signed) {
8383
assert counts.length == 256;
8484
long[] offsets = transformCountsToOffsets(
@@ -106,7 +106,7 @@ private static void sortAtByte(
106106
* significant byte. If the byte does not need sorting the array will be null.
107107
*/
108108
private static long[][] getCounts(
109-
LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
109+
LongArray array, long numRecords, long startByteIndex, long endByteIndex) {
110110
long[][] counts = new long[8][];
111111
// Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
112112
// If all the byte values at a particular index are the same we don't need to count it.
@@ -121,12 +121,12 @@ private static long[][] getCounts(
121121
}
122122
long bitsChanged = bitwiseMin ^ bitwiseMax;
123123
// Compute counts for each byte index.
124-
for (int i = startByteIndex; i <= endByteIndex; i++) {
124+
for (long i = startByteIndex; i <= endByteIndex; i++) {
125125
if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
126-
counts[i] = new long[256];
126+
counts[(int)i] = new long[256];
127127
// TODO(ekl) consider computing all the counts in one pass.
128128
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
129-
counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
129+
counts[(int)i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
130130
}
131131
}
132132
}
@@ -146,7 +146,7 @@ private static long[][] getCounts(
146146
* @return the input counts array.
147147
*/
148148
private static long[] transformCountsToOffsets(
149-
long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
149+
long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
150150
boolean desc, boolean signed) {
151151
assert counts.length == 256;
152152
int start = signed ? 128 : 0;  // output the negative records first (values 128-255).
@@ -176,41 +176,41 @@ private static long[] transformCountsToOffsets(
176176
*/
177177
public static int sortKeyPrefixArray(
178178
LongArray array,
179-
int startIndex,
180-
int numRecords,
181-
int startByteIndex,
182-
int endByteIndex,
179+
long startIndex,
180+
long numRecords,
181+
long startByteIndex,
182+
long endByteIndex,
183183
boolean desc,
184184
boolean signed) {
185185
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
186186
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
187187
assert endByteIndex > startByteIndex;
188188
assert numRecords * 4 <= array.size();
189-
int inIndex = startIndex;
190-
int outIndex = startIndex + numRecords * 2;
189+
long inIndex = startIndex;
190+
long outIndex = startIndex + numRecords * 2L;
191191
if (numRecords > 0) {
192192
long[][] counts = getKeyPrefixArrayCounts(
193193
array, startIndex, numRecords, startByteIndex, endByteIndex);
194-
for (int i = startByteIndex; i <= endByteIndex; i++) {
195-
if (counts[i] != null) {
194+
for (long i = startByteIndex; i <= endByteIndex; i++) {
195+
if (counts[(int)i] != null) {
196196
sortKeyPrefixArrayAtByte(
197-
array, numRecords, counts[i], i, inIndex, outIndex,
197+
array, numRecords, counts[(int)i], i, inIndex, outIndex,
198198
desc, signed && i == endByteIndex);
199-
int tmp = inIndex;
199+
long tmp = inIndex;
200200
inIndex = outIndex;
201201
outIndex = tmp;
202202
}
203203
}
204204
}
205-
return inIndex;
205+
return (int)inIndex;
206206
}
207207

208208
/**
209209
* Specialization of getCounts() for key-prefix arrays. We could probably combine this with
210210
* getCounts with some added parameters but that seems to hurt in benchmarks.
211211
*/
212212
private static long[][] getKeyPrefixArrayCounts(
213-
LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
213+
LongArray array, long startIndex, long numRecords, long startByteIndex, long endByteIndex) {
214214
long[][] counts = new long[8][];
215215
long bitwiseMax = 0;
216216
long bitwiseMin = -1L;
@@ -223,11 +223,11 @@ private static long[][] getKeyPrefixArrayCounts(
223223
bitwiseMin &= value;
224224
}
225225
long bitsChanged = bitwiseMin ^ bitwiseMax;
226-
for (int i = startByteIndex; i <= endByteIndex; i++) {
226+
for (long i = startByteIndex; i <= endByteIndex; i++) {
227227
if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
228-
counts[i] = new long[256];
228+
counts[(int)i] = new long[256];
229229
for (long offset = baseOffset; offset < limit; offset += 16) {
230-
counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
230+
counts[(int)i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
231231
}
232232
}
233233
}
@@ -238,7 +238,7 @@ private static long[][] getKeyPrefixArrayCounts(
238238
* Specialization of sortAtByte() for key-prefix arrays.
239239
*/
240240
private static void sortKeyPrefixArrayAtByte(
241-
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
241+
LongArray array, long numRecords, long[] counts, long byteIdx, long inIndex, long outIndex,
242242
boolean desc, boolean signed) {
243243
assert counts.length == 256;
244244
long[] offsets = transformCountsToOffsets(

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() {
322322
if (sortComparator != null) {
323323
if (this.radixSortSupport != null) {
324324
offset = RadixSort.sortKeyPrefixArray(
325-
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
325+
array, (long)nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0L, 7L,
326326
radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
327327
} else {
328328
MemoryBlock unused = new MemoryBlock(

core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.util.collection.Sorter
3030
import org.apache.spark.util.random.XORShiftRandom
3131

3232
class RadixSortSuite extends SparkFunSuite with Logging {
33-
private val N = 10000 // scale this down for more readable results
33+
private val N = 10000.toLong // scale this down for more readable results
3434

3535
/**
3636
* Describes a type of sort to test, e.g. two's complement descending. Each sort type has
@@ -40,7 +40,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
4040
case class RadixSortType(
4141
name: String,
4242
referenceComparator: PrefixComparator,
43-
startByteIdx: Int, endByteIdx: Int, descending: Boolean, signed: Boolean, nullsFirst: Boolean)
43+
startByteIdx: Long, endByteIdx: Long, descending: Boolean, signed: Boolean, nullsFirst: Boolean)
4444

4545
val SORT_TYPES_TO_TEST = Seq(
4646
RadixSortType("unsigned binary data asc nulls first",
@@ -73,22 +73,22 @@ class RadixSortSuite extends SparkFunSuite with Logging {
7373
},
7474
2, 4, false, false, true))
7575

76-
private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
77-
val ref = Array.tabulate[Long](size) { i => rand }
78-
val extended = ref ++ Array.fill[Long](size)(0)
76+
private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
77+
val ref = Array.tabulate[Long](size.toInt) { i => rand }
78+
val extended = ref ++ Array.fill[Long](size.toInt)(0)
7979
(ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
8080
}
8181

82-
private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
83-
val ref = Array.tabulate[Long](size * 2) { i => rand }
84-
val extended = ref ++ Array.fill[Long](size * 2)(0)
82+
private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
83+
val ref = Array.tabulate[Long]((size * 2).toInt) { i => rand }
84+
val extended = ref ++ Array.fill[Long]((size * 2).toInt)(0)
8585
(new LongArray(MemoryBlock.fromLongArray(ref)),
8686
new LongArray(MemoryBlock.fromLongArray(extended)))
8787
}
8888

89-
private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
89+
private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
9090
var i = 0
91-
val out = new Array[Long](length)
91+
val out = new Array[Long](length.toInt)
9292
while (i < length) {
9393
out(i) = array.get(offset + i)
9494
i += 1
@@ -107,10 +107,10 @@ class RadixSortSuite extends SparkFunSuite with Logging {
107107
}
108108
}
109109

110-
private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
110+
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
111111
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
112112
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
113-
buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
113+
buf, lo.toInt, hi.toInt, new Comparator[RecordPointerAndKeyPrefix] {
114114
override def compare(
115115
r1: RecordPointerAndKeyPrefix,
116116
r2: RecordPointerAndKeyPrefix): Int = {
@@ -156,7 +156,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
156156
val (ref, buffer) = generateTestData(N, rand.nextLong)
157157
Arrays.sort(ref, toJavaComparator(sortType.referenceComparator))
158158
val outOffset = RadixSort.sort(
159-
buffer, N, sortType.startByteIdx, sortType.endByteIdx,
159+
buffer, N.toLong, sortType.startByteIdx, sortType.endByteIdx,
160160
sortType.descending, sortType.signed)
161161
val result = collectToArray(buffer, outOffset, N)
162162
assert(ref.view == result.view)

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class SortBenchmark extends BenchmarkBase {
8080
}
8181
val buf = new LongArray(MemoryBlock.fromLongArray(array))
8282
timer.startTiming()
83-
RadixSort.sort(buf, size, 0, 7, false, false)
83+
RadixSort.sort(buf, size.toLong, 0.toLong, 7.toLong, false, false)
8484
timer.stopTiming()
8585
}
8686
benchmark.addTimerCase("radix sort two bytes") { timer =>
@@ -92,7 +92,7 @@ class SortBenchmark extends BenchmarkBase {
9292
}
9393
val buf = new LongArray(MemoryBlock.fromLongArray(array))
9494
timer.startTiming()
95-
RadixSort.sort(buf, size, 0, 7, false, false)
95+
RadixSort.sort(buf, size.toLong, 0.toLong, 7.toLong, false, false)
9696
timer.stopTiming()
9797
}
9898
benchmark.addTimerCase("radix sort eight bytes") { timer =>
@@ -104,13 +104,13 @@ class SortBenchmark extends BenchmarkBase {
104104
}
105105
val buf = new LongArray(MemoryBlock.fromLongArray(array))
106106
timer.startTiming()
107-
RadixSort.sort(buf, size, 0, 7, false, false)
107+
RadixSort.sort(buf, size.toLong, 0.toLong, 7.toLong, false, false)
108108
timer.stopTiming()
109109
}
110110
benchmark.addTimerCase("radix sort key prefix array") { timer =>
111111
val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
112112
timer.startTiming()
113-
RadixSort.sortKeyPrefixArray(buf2, 0, size, 0, 7, false, false)
113+
RadixSort.sortKeyPrefixArray(buf2, 0.toLong, size.toLong, 0.toLong, 7.toLong, false, false)
114114
timer.stopTiming()
115115
}
116116
benchmark.run()

0 commit comments

Comments
 (0)