Skip to content

Commit f480fb2

Browse files
committed
WIP in mega-refactoring towards shuffle-specific sort.
1 parent 57f1ec0 commit f480fb2

16 files changed

+497
-1074
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.shuffle.unsafe;
19+
20+
import org.apache.spark.serializer.DeserializationStream;
21+
import org.apache.spark.serializer.SerializationStream;
22+
import org.apache.spark.serializer.SerializerInstance;
23+
import scala.reflect.ClassTag;
24+
25+
import java.io.InputStream;
26+
import java.io.OutputStream;
27+
import java.nio.ByteBuffer;
28+
29+
class DummySerializerInstance extends SerializerInstance {
30+
@Override
31+
public SerializationStream serializeStream(OutputStream s) {
32+
return new SerializationStream() {
33+
@Override
34+
public void flush() {
35+
36+
}
37+
38+
@Override
39+
public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
40+
return null;
41+
}
42+
43+
@Override
44+
public void close() {
45+
46+
}
47+
};
48+
}
49+
50+
@Override
51+
public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
52+
return null;
53+
}
54+
55+
@Override
56+
public DeserializationStream deserializeStream(InputStream s) {
57+
return null;
58+
}
59+
60+
@Override
61+
public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
62+
return null;
63+
}
64+
65+
@Override
66+
public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
67+
return null;
68+
}
69+
}

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterIterator.java renamed to core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,20 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.unsafe.sort;
18+
package org.apache.spark.shuffle.unsafe;
1919

20-
import java.io.IOException;
20+
import org.apache.spark.storage.BlockId;
2121

22-
public abstract class UnsafeSorterIterator {
22+
import java.io.File;
2323

24-
public abstract boolean hasNext();
24+
final class SpillInfo {
25+
final long[] partitionLengths;
26+
final File file;
27+
final BlockId blockId;
2528

26-
public abstract void loadNext() throws IOException;
27-
28-
public abstract Object getBaseObject();
29-
30-
public abstract long getBaseOffset();
31-
32-
public abstract int getRecordLength();
33-
34-
public abstract long getKeyPrefix();
29+
public SpillInfo(int numPartitions, File file, BlockId blockId) {
30+
this.partitionLengths = new long[numPartitions];
31+
this.file = file;
32+
this.blockId = blockId;
33+
}
3534
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.shuffle.unsafe;
19+
20+
import org.apache.spark.util.collection.SortDataFormat;
21+
22+
final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
23+
24+
public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
25+
26+
private UnsafeShuffleSortDataFormat() { }
27+
28+
@Override
29+
public PackedRecordPointer getKey(long[] data, int pos) {
30+
// Since we re-use keys, this method shouldn't be called.
31+
throw new UnsupportedOperationException();
32+
}
33+
34+
@Override
35+
public PackedRecordPointer newKey() {
36+
return new PackedRecordPointer();
37+
}
38+
39+
@Override
40+
public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
41+
reuse.packedRecordPointer = data[pos];
42+
return reuse;
43+
}
44+
45+
@Override
46+
public void swap(long[] data, int pos0, int pos1) {
47+
final long temp = data[pos0];
48+
data[pos0] = data[pos1];
49+
data[pos1] = temp;
50+
}
51+
52+
@Override
53+
public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
54+
dst[dstPos] = src[srcPos];
55+
}
56+
57+
@Override
58+
public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
59+
System.arraycopy(src, srcPos, dst, dstPos, length);
60+
}
61+
62+
@Override
63+
public long[] allocate(int length) {
64+
assert (length < Integer.MAX_VALUE) : "Length " + length + " is too large";
65+
return new long[length];
66+
}
67+
68+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.shuffle.unsafe;
19+
20+
import java.util.Comparator;
21+
22+
import org.apache.spark.util.collection.Sorter;
23+
24+
public final class UnsafeShuffleSorter {
25+
26+
private final Sorter<PackedRecordPointer, long[]> sorter;
27+
private final Comparator<PackedRecordPointer> sortComparator;
28+
29+
private long[] sortBuffer;
30+
31+
/**
32+
* The position in the sort buffer where new records can be inserted.
33+
*/
34+
private int sortBufferInsertPosition = 0;
35+
36+
public UnsafeShuffleSorter(int initialSize) {
37+
assert (initialSize > 0);
38+
this.sortBuffer = new long[initialSize];
39+
this.sorter =
40+
new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
41+
this.sortComparator = new Comparator<PackedRecordPointer>() {
42+
@Override
43+
public int compare(PackedRecordPointer left, PackedRecordPointer right) {
44+
return left.getPartitionId() - right.getPartitionId();
45+
}
46+
};
47+
}
48+
49+
public void expandSortBuffer() {
50+
final long[] oldBuffer = sortBuffer;
51+
sortBuffer = new long[oldBuffer.length * 2];
52+
System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length);
53+
}
54+
55+
public boolean hasSpaceForAnotherRecord() {
56+
return sortBufferInsertPosition + 1 < sortBuffer.length;
57+
}
58+
59+
public long getMemoryUsage() {
60+
return sortBuffer.length * 8L;
61+
}
62+
63+
// TODO: clairify assumption that pointer points to record length.
64+
public void insertRecord(long recordPointer, int partitionId) {
65+
if (!hasSpaceForAnotherRecord()) {
66+
expandSortBuffer();
67+
}
68+
sortBuffer[sortBufferInsertPosition] =
69+
PackedRecordPointer.packPointer(recordPointer, partitionId);
70+
sortBufferInsertPosition++;
71+
}
72+
73+
public static abstract class UnsafeShuffleSorterIterator {
74+
75+
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
76+
77+
public abstract boolean hasNext();
78+
79+
public abstract void loadNext();
80+
81+
}
82+
83+
/**
84+
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
85+
* {@code next()} will return the same mutable object.
86+
*/
87+
public UnsafeShuffleSorterIterator getSortedIterator() {
88+
sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator);
89+
return new UnsafeShuffleSorterIterator() {
90+
91+
private int position = 0;
92+
93+
@Override
94+
public boolean hasNext() {
95+
return position < sortBufferInsertPosition;
96+
}
97+
98+
@Override
99+
public void loadNext() {
100+
packedRecordPointer.packedRecordPointer = sortBuffer[position];
101+
position++;
102+
}
103+
};
104+
}
105+
}

0 commit comments

Comments
 (0)