20
20
import java .util .Comparator ;
21
21
22
22
/**
23
- * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort."
23
+ * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort."
24
24
* See the method comment on sort() for more details.
25
25
*
26
26
* This has been kept in Java with the original style in order to match very closely with the
27
- * Anroid source code, and thus be easy to verify correctness.
27
+ * Android source code, and thus be easy to verify correctness. The class is package private. We put
28
+ * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to
29
+ * package org.apache.spark.
28
30
*
29
31
* The purpose of the port is to generalize the interface to the sort to accept input data formats
30
32
* besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
31
33
* uses this to sort an Array with alternating elements of the form [key, value, key, value].
32
34
* This generalization comes with minimal overhead -- see SortDataFormat for more information.
35
+ *
36
+ * We allow key reuse to prevent creating many key objects -- see SortDataFormat.
37
+ *
38
+ * @see org.apache.spark.util.collection.SortDataFormat
39
+ * @see org.apache.spark.util.collection.Sorter
33
40
*/
34
- class Sorter <K , Buffer > {
41
+ class TimSort <K , Buffer > {
35
42
36
43
/**
37
44
* This is the minimum sized sequence that will be merged. Shorter
@@ -54,7 +61,7 @@ class Sorter<K, Buffer> {
54
61
55
62
private final SortDataFormat <K , Buffer > s ;
56
63
57
- public Sorter (SortDataFormat <K , Buffer > sortDataFormat ) {
64
+ public TimSort (SortDataFormat <K , Buffer > sortDataFormat ) {
58
65
this .s = sortDataFormat ;
59
66
}
60
67
@@ -91,7 +98,7 @@ public Sorter(SortDataFormat<K, Buffer> sortDataFormat) {
91
98
*
92
99
* @author Josh Bloch
93
100
*/
94
- void sort (Buffer a , int lo , int hi , Comparator <? super K > c ) {
101
+ public void sort (Buffer a , int lo , int hi , Comparator <? super K > c ) {
95
102
assert c != null ;
96
103
97
104
int nRemaining = hi - lo ;
@@ -162,10 +169,13 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super
162
169
if (start == lo )
163
170
start ++;
164
171
172
+ K key0 = s .newKey ();
173
+ K key1 = s .newKey ();
174
+
165
175
Buffer pivotStore = s .allocate (1 );
166
176
for ( ; start < hi ; start ++) {
167
177
s .copyElement (a , start , pivotStore , 0 );
168
- K pivot = s .getKey (pivotStore , 0 );
178
+ K pivot = s .getKey (pivotStore , 0 , key0 );
169
179
170
180
// Set left (and right) to the index where a[start] (pivot) belongs
171
181
int left = lo ;
@@ -178,7 +188,7 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super
178
188
*/
179
189
while (left < right ) {
180
190
int mid = (left + right ) >>> 1 ;
181
- if (c .compare (pivot , s .getKey (a , mid )) < 0 )
191
+ if (c .compare (pivot , s .getKey (a , mid , key1 )) < 0 )
182
192
right = mid ;
183
193
else
184
194
left = mid + 1 ;
@@ -235,13 +245,16 @@ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? supe
235
245
if (runHi == hi )
236
246
return 1 ;
237
247
248
+ K key0 = s .newKey ();
249
+ K key1 = s .newKey ();
250
+
238
251
// Find end of run, and reverse range if descending
239
- if (c .compare (s .getKey (a , runHi ++), s .getKey (a , lo )) < 0 ) { // Descending
240
- while (runHi < hi && c .compare (s .getKey (a , runHi ), s .getKey (a , runHi - 1 )) < 0 )
252
+ if (c .compare (s .getKey (a , runHi ++, key0 ), s .getKey (a , lo , key1 )) < 0 ) { // Descending
253
+ while (runHi < hi && c .compare (s .getKey (a , runHi , key0 ), s .getKey (a , runHi - 1 , key1 )) < 0 )
241
254
runHi ++;
242
255
reverseRange (a , lo , runHi );
243
256
} else { // Ascending
244
- while (runHi < hi && c .compare (s .getKey (a , runHi ), s .getKey (a , runHi - 1 )) >= 0 )
257
+ while (runHi < hi && c .compare (s .getKey (a , runHi , key0 ), s .getKey (a , runHi - 1 , key1 )) >= 0 )
245
258
runHi ++;
246
259
}
247
260
@@ -468,11 +481,13 @@ private void mergeAt(int i) {
468
481
}
469
482
stackSize --;
470
483
484
+ K key0 = s .newKey ();
485
+
471
486
/*
472
487
* Find where the first element of run2 goes in run1. Prior elements
473
488
* in run1 can be ignored (because they're already in place).
474
489
*/
475
- int k = gallopRight (s .getKey (a , base2 ), a , base1 , len1 , 0 , c );
490
+ int k = gallopRight (s .getKey (a , base2 , key0 ), a , base1 , len1 , 0 , c );
476
491
assert k >= 0 ;
477
492
base1 += k ;
478
493
len1 -= k ;
@@ -483,7 +498,7 @@ private void mergeAt(int i) {
483
498
* Find where the last element of run1 goes in run2. Subsequent elements
484
499
* in run2 can be ignored (because they're already in place).
485
500
*/
486
- len2 = gallopLeft (s .getKey (a , base1 + len1 - 1 ), a , base2 , len2 , len2 - 1 , c );
501
+ len2 = gallopLeft (s .getKey (a , base1 + len1 - 1 , key0 ), a , base2 , len2 , len2 - 1 , c );
487
502
assert len2 >= 0 ;
488
503
if (len2 == 0 )
489
504
return ;
@@ -517,10 +532,12 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
517
532
assert len > 0 && hint >= 0 && hint < len ;
518
533
int lastOfs = 0 ;
519
534
int ofs = 1 ;
520
- if (c .compare (key , s .getKey (a , base + hint )) > 0 ) {
535
+ K key0 = s .newKey ();
536
+
537
+ if (c .compare (key , s .getKey (a , base + hint , key0 )) > 0 ) {
521
538
// Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
522
539
int maxOfs = len - hint ;
523
- while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint + ofs )) > 0 ) {
540
+ while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint + ofs , key0 )) > 0 ) {
524
541
lastOfs = ofs ;
525
542
ofs = (ofs << 1 ) + 1 ;
526
543
if (ofs <= 0 ) // int overflow
@@ -535,7 +552,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
535
552
} else { // key <= a[base + hint]
536
553
// Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
537
554
final int maxOfs = hint + 1 ;
538
- while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint - ofs )) <= 0 ) {
555
+ while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint - ofs , key0 )) <= 0 ) {
539
556
lastOfs = ofs ;
540
557
ofs = (ofs << 1 ) + 1 ;
541
558
if (ofs <= 0 ) // int overflow
@@ -560,7 +577,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
560
577
while (lastOfs < ofs ) {
561
578
int m = lastOfs + ((ofs - lastOfs ) >>> 1 );
562
579
563
- if (c .compare (key , s .getKey (a , base + m )) > 0 )
580
+ if (c .compare (key , s .getKey (a , base + m , key0 )) > 0 )
564
581
lastOfs = m + 1 ; // a[base + m] < key
565
582
else
566
583
ofs = m ; // key <= a[base + m]
@@ -587,10 +604,12 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
587
604
588
605
int ofs = 1 ;
589
606
int lastOfs = 0 ;
590
- if (c .compare (key , s .getKey (a , base + hint )) < 0 ) {
607
+ K key1 = s .newKey ();
608
+
609
+ if (c .compare (key , s .getKey (a , base + hint , key1 )) < 0 ) {
591
610
// Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
592
611
int maxOfs = hint + 1 ;
593
- while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint - ofs )) < 0 ) {
612
+ while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint - ofs , key1 )) < 0 ) {
594
613
lastOfs = ofs ;
595
614
ofs = (ofs << 1 ) + 1 ;
596
615
if (ofs <= 0 ) // int overflow
@@ -606,7 +625,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
606
625
} else { // a[b + hint] <= key
607
626
// Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
608
627
int maxOfs = len - hint ;
609
- while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint + ofs )) >= 0 ) {
628
+ while (ofs < maxOfs && c .compare (key , s .getKey (a , base + hint + ofs , key1 )) >= 0 ) {
610
629
lastOfs = ofs ;
611
630
ofs = (ofs << 1 ) + 1 ;
612
631
if (ofs <= 0 ) // int overflow
@@ -630,7 +649,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
630
649
while (lastOfs < ofs ) {
631
650
int m = lastOfs + ((ofs - lastOfs ) >>> 1 );
632
651
633
- if (c .compare (key , s .getKey (a , base + m )) < 0 )
652
+ if (c .compare (key , s .getKey (a , base + m , key1 )) < 0 )
634
653
ofs = m ; // key < a[b + m]
635
654
else
636
655
lastOfs = m + 1 ; // a[b + m] <= key
@@ -679,6 +698,9 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
679
698
return ;
680
699
}
681
700
701
+ K key0 = s .newKey ();
702
+ K key1 = s .newKey ();
703
+
682
704
Comparator <? super K > c = this .c ; // Use local variable for performance
683
705
int minGallop = this .minGallop ; // " " " " "
684
706
outer :
@@ -692,7 +714,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
692
714
*/
693
715
do {
694
716
assert len1 > 1 && len2 > 0 ;
695
- if (c .compare (s .getKey (a , cursor2 ), s .getKey (tmp , cursor1 )) < 0 ) {
717
+ if (c .compare (s .getKey (a , cursor2 , key0 ), s .getKey (tmp , cursor1 , key1 )) < 0 ) {
696
718
s .copyElement (a , cursor2 ++, a , dest ++);
697
719
count2 ++;
698
720
count1 = 0 ;
@@ -714,7 +736,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
714
736
*/
715
737
do {
716
738
assert len1 > 1 && len2 > 0 ;
717
- count1 = gallopRight (s .getKey (a , cursor2 ), tmp , cursor1 , len1 , 0 , c );
739
+ count1 = gallopRight (s .getKey (a , cursor2 , key0 ), tmp , cursor1 , len1 , 0 , c );
718
740
if (count1 != 0 ) {
719
741
s .copyRange (tmp , cursor1 , a , dest , count1 );
720
742
dest += count1 ;
@@ -727,7 +749,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
727
749
if (--len2 == 0 )
728
750
break outer ;
729
751
730
- count2 = gallopLeft (s .getKey (tmp , cursor1 ), a , cursor2 , len2 , 0 , c );
752
+ count2 = gallopLeft (s .getKey (tmp , cursor1 , key0 ), a , cursor2 , len2 , 0 , c );
731
753
if (count2 != 0 ) {
732
754
s .copyRange (a , cursor2 , a , dest , count2 );
733
755
dest += count2 ;
@@ -784,6 +806,9 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
784
806
int cursor2 = len2 - 1 ; // Indexes into tmp array
785
807
int dest = base2 + len2 - 1 ; // Indexes into a
786
808
809
+ K key0 = s .newKey ();
810
+ K key1 = s .newKey ();
811
+
787
812
// Move last element of first run and deal with degenerate cases
788
813
s .copyElement (a , cursor1 --, a , dest --);
789
814
if (--len1 == 0 ) {
@@ -811,7 +836,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
811
836
*/
812
837
do {
813
838
assert len1 > 0 && len2 > 1 ;
814
- if (c .compare (s .getKey (tmp , cursor2 ), s .getKey (a , cursor1 )) < 0 ) {
839
+ if (c .compare (s .getKey (tmp , cursor2 , key0 ), s .getKey (a , cursor1 , key1 )) < 0 ) {
815
840
s .copyElement (a , cursor1 --, a , dest --);
816
841
count1 ++;
817
842
count2 = 0 ;
@@ -833,7 +858,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
833
858
*/
834
859
do {
835
860
assert len1 > 0 && len2 > 1 ;
836
- count1 = len1 - gallopRight (s .getKey (tmp , cursor2 ), a , base1 , len1 , len1 - 1 , c );
861
+ count1 = len1 - gallopRight (s .getKey (tmp , cursor2 , key0 ), a , base1 , len1 , len1 - 1 , c );
837
862
if (count1 != 0 ) {
838
863
dest -= count1 ;
839
864
cursor1 -= count1 ;
@@ -846,7 +871,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
846
871
if (--len2 == 1 )
847
872
break outer ;
848
873
849
- count2 = len2 - gallopLeft (s .getKey (a , cursor1 ), tmp , 0 , len2 , len2 - 1 , c );
874
+ count2 = len2 - gallopLeft (s .getKey (a , cursor1 , key0 ), tmp , 0 , len2 , len2 - 1 , c );
850
875
if (count2 != 0 ) {
851
876
dest -= count2 ;
852
877
cursor2 -= count2 ;
0 commit comments