Skip to content

Commit 7fd964e

Browse files
committed
make RoundRobinPartitioning output deterministic.
1 parent 45b4bbf commit 7fd964e

File tree

16 files changed

+231
-29
lines changed

16 files changed

+231
-29
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ public abstract class RecordComparator {
3232
public abstract int compare(
3333
Object leftBaseObject,
3434
long leftBaseOffset,
35+
int leftBaseLength,
3536
Object rightBaseObject,
36-
long rightBaseOffset);
37+
long rightBaseOffset,
38+
int rightBaseLength);
3739
}

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
6262
int uaoSize = UnsafeAlignedOffset.getUaoSize();
6363
if (prefixComparisonResult == 0) {
6464
final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
65-
// skip length
6665
final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
66+
final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize);
6767
final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
68-
// skip length
6968
final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
70-
return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
69+
final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize);
70+
return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2,
71+
baseOffset2, baseLength2);
7172
} else {
7273
return prefixComparisonResult;
7374
}

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger {
3535
prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
3636
if (prefixComparisonResult == 0) {
3737
return recordComparator.compare(
38-
left.getBaseObject(), left.getBaseOffset(),
39-
right.getBaseObject(), right.getBaseOffset());
38+
left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
39+
right.getBaseObject(), right.getBaseOffset(), right.getRecordLength());
4040
} else {
4141
return prefixComparisonResult;
4242
}

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite {
7272
public int compare(
7373
Object leftBaseObject,
7474
long leftBaseOffset,
75+
int leftBaseLength,
7576
Object rightBaseObject,
76-
long rightBaseOffset) {
77+
long rightBaseOffset,
78+
int rightBaseLength) {
7779
return 0;
7880
}
7981
};

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
9898
public int compare(
9999
Object leftBaseObject,
100100
long leftBaseOffset,
101+
int leftBaseLength,
101102
Object rightBaseObject,
102-
long rightBaseOffset) {
103+
long rightBaseOffset,
104+
int rightBaseLength) {
103105
return 0;
104106
}
105107
};
@@ -164,8 +166,10 @@ public void freeAfterOOM() {
164166
public int compare(
165167
Object leftBaseObject,
166168
long leftBaseOffset,
169+
int leftBaseLength,
167170
Object rightBaseObject,
168-
long rightBaseOffset) {
171+
long rightBaseOffset,
172+
int rightBaseLength) {
169173
return 0;
170174
}
171175
};

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
222222
val oldModel = new OldWord2VecModel(word2VecMap)
223223
val instance = new Word2VecModel("myWord2VecModel", oldModel)
224224
val newInstance = testDefaultReadWrite(instance)
225-
assert(newInstance.getVectors.collect() === instance.getVectors.collect())
225+
assert(newInstance.getVectors.collect().sortBy(_.getString(0)) ===
226+
instance.getVectors.collect().sortBy(_.getString(0)))
226227
}
227228

228229
test("Word2Vec works with input that is non-nullable (NGram)") {
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution;
19+
20+
import org.apache.spark.unsafe.Platform;
21+
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
22+
23+
public final class RecordBinaryComparator extends RecordComparator {
24+
25+
// TODO(jiangxb) Add test suite for this.
26+
@Override
27+
public int compare(
28+
Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) {
29+
int i = 0;
30+
int res = 0;
31+
32+
// If the arrays have different length, the longer one is larger.
33+
if (leftLen != rightLen) {
34+
return leftLen - rightLen;
35+
}
36+
37+
// The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since
38+
// we have guaranteed `leftLen` == `rightLen`.
39+
40+
// check if stars align and we can get both offsets to be aligned
41+
if ((leftOff % 8) == (rightOff % 8)) {
42+
while ((leftOff + i) % 8 != 0 && i < leftLen) {
43+
res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
44+
(Platform.getByte(rightObj, rightOff + i) & 0xff);
45+
if (res != 0) return res;
46+
i += 1;
47+
}
48+
}
49+
// for architectures that support unaligned accesses, chew it up 8 bytes at a time
50+
if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) {
51+
while (i <= leftLen - 8) {
52+
res = (int) ((Platform.getLong(leftObj, leftOff + i) -
53+
Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE);
54+
if (res != 0) return res;
55+
i += 8;
56+
}
57+
}
58+
// this will finish off the unaligned comparisons, or do the entire aligned comparison
59+
// whichever is needed.
60+
while (i < leftLen) {
61+
res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
62+
(Platform.getByte(rightObj, rightOff + i) & 0xff);
63+
if (res != 0) return res;
64+
i += 1;
65+
}
66+
67+
// The two arrays are equal.
68+
return 0;
69+
}
70+
}

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
package org.apache.spark.sql.execution;
1919

2020
import java.io.IOException;
21+
import java.util.function.Supplier;
2122

23+
import org.apache.spark.sql.catalyst.util.TypeUtils;
2224
import scala.collection.AbstractIterator;
2325
import scala.collection.Iterator;
2426
import scala.math.Ordering;
@@ -56,26 +58,50 @@ public abstract static class PrefixComputer {
5658

5759
public static class Prefix {
5860
/** Key prefix value, or the null prefix value if isNull = true. **/
59-
long value;
61+
public long value;
6062

6163
/** Whether the key is null. */
62-
boolean isNull;
64+
public boolean isNull;
6365
}
6466

6567
/**
6668
* Computes prefix for the given row. For efficiency, the returned object may be reused in
6769
* further calls to a given PrefixComputer.
6870
*/
69-
abstract Prefix computePrefix(InternalRow row);
71+
public abstract Prefix computePrefix(InternalRow row);
7072
}
7173

72-
public UnsafeExternalRowSorter(
74+
/**
 * Static factory that builds a sorter whose record comparator is obtained from the supplied
 * {@code recordComparatorSupplier}. Every argument is forwarded unchanged to the private
 * constructor.
 *
 * @throws IOException as declared by the private constructor this method delegates to
 */
public static UnsafeExternalRowSorter createWithRecordComparator(
75+
StructType schema,
76+
Supplier<RecordComparator> recordComparatorSupplier,
77+
PrefixComparator prefixComparator,
78+
PrefixComputer prefixComputer,
79+
long pageSizeBytes,
80+
boolean canUseRadixSort) throws IOException {
81+
return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
82+
prefixComputer, pageSizeBytes, canUseRadixSort);
83+
}
84+
85+
/**
 * Static factory that builds a sorter driven by a row {@code ordering}. The ordering is wrapped
 * in a supplier that creates a {@code RowComparator} from the ordering and
 * {@code schema.length()}; that supplier and the remaining arguments are then passed to the
 * private constructor.
 *
 * @throws IOException as declared by the private constructor this method delegates to
 */
public static UnsafeExternalRowSorter create(
7386
StructType schema,
7487
Ordering<InternalRow> ordering,
7588
PrefixComparator prefixComparator,
7689
PrefixComputer prefixComputer,
7790
long pageSizeBytes,
7891
boolean canUseRadixSort) throws IOException {
92+
Supplier<RecordComparator> recordComparatorSupplier =
93+
() -> new RowComparator(ordering, schema.length());
94+
return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
95+
prefixComputer, pageSizeBytes, canUseRadixSort);
96+
}
97+
98+
private UnsafeExternalRowSorter(
99+
StructType schema,
100+
Supplier<RecordComparator> recordComparatorSupplier,
101+
PrefixComparator prefixComparator,
102+
PrefixComputer prefixComputer,
103+
long pageSizeBytes,
104+
boolean canUseRadixSort) throws IOException {
79105
this.schema = schema;
80106
this.prefixComputer = prefixComputer;
81107
final SparkEnv sparkEnv = SparkEnv.get();
@@ -85,7 +111,7 @@ public UnsafeExternalRowSorter(
85111
sparkEnv.blockManager(),
86112
sparkEnv.serializerManager(),
87113
taskContext,
88-
() -> new RowComparator(ordering, schema.length()),
114+
recordComparatorSupplier,
89115
prefixComparator,
90116
sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
91117
DEFAULT_INITIAL_SORT_BUFFER_SIZE),
@@ -206,7 +232,13 @@ private static final class RowComparator extends RecordComparator {
206232
}
207233

208234
@Override
209-
public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
235+
public int compare(
236+
Object baseObj1,
237+
long baseOff1,
238+
int baseLen1,
239+
Object baseObj2,
240+
long baseOff2,
241+
int baseLen2) {
210242
// Note that since ordering doesn't need the total length of the record, we just pass 0
211243
// into the row.
212244
row1.pointTo(baseObj1, baseOff1, 0);

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,18 @@ object SQLConf {
11281128
.checkValues(PartitionOverwriteMode.values.map(_.toString))
11291129
.createWithDefault(PartitionOverwriteMode.STATIC.toString)
11301130

1131+
val SORT_BEFORE_REPARTITION =
1132+
buildConf("spark.sql.execution.sortBeforeRepartition")
1133+
.internal()
1134+
.doc("When perform a repartition following a shuffle, the output row ordering would be " +
1135+
"nondeterministic. If some downstream stages fail and some tasks of the repartition " +
1136+
"stage retry, these tasks may generate different data, and that can lead to correctness " +
1137+
"issues. Turn on this config to insert a local sort before actually doing repartition " +
1138+
"to generate consistent repartition results. The performance of repartition() may go " +
1139+
"down since we insert extra local sort before it.")
1140+
.booleanConf
1141+
.createWithDefault(true)
1142+
11311143
object Deprecated {
11321144
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
11331145
}
@@ -1278,6 +1290,8 @@ class SQLConf extends Serializable with Logging {
12781290

12791291
def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
12801292

1293+
def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
1294+
12811295
/**
12821296
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
12831297
* identifiers are equal.

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,13 @@ private static final class KVComparator extends RecordComparator {
241241
}
242242

243243
@Override
244-
public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
244+
public int compare(
245+
Object baseObj1,
246+
long baseOff1,
247+
int baseLen1,
248+
Object baseObj2,
249+
long baseOff2,
250+
int baseLen2) {
245251
// Note that since ordering doesn't need the total length of the record, we just pass 0
246252
// into the row.
247253
row1.pointTo(baseObj1, baseOff1 + 4, 0);

0 commit comments

Comments
 (0)