
Commit d7c3aae

jiangxb1987, bersprockets, and zhengruifeng authored and committed
[SPARK-23207][SPARK-22905][SPARK-24564][SPARK-25114][SQL][BACKPORT-2.2] Shuffle+Repartition on a DataFrame could lead to incorrect answers
## What changes were proposed in this pull request?

Back port of #20393.

Currently shuffle repartition uses RoundRobinPartitioning; the generated result is nondeterministic because the ordering of the input rows is not determined. The bug can be triggered when there is a repartition call following a shuffle (which would lead to non-deterministic row ordering), as the pattern shows below:

upstream stage -> repartition stage -> result stage (-> indicates a shuffle)

When one of the executor processes goes down, some tasks of the repartition stage will be retried and generate an inconsistent ordering, and some tasks of the result stage will be retried and generate different data. The following code returns 931532 instead of 1000000:

```
import scala.sys.process._
import org.apache.spark.TaskContext

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()
```

In this PR, we propose the most straightforward way to fix this problem: perform a local sort before partitioning. Once the input row ordering is made deterministic, the function from rows to partitions is fully deterministic too.

The downside of this approach is that, with the extra local sort inserted, the performance of repartition() will go down, so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this patch is applied. The patch is enabled by default to be safe by default, but users may choose to manually turn it off to avoid the performance regression.

This patch also changes the output row ordering of repartition(), which causes a number of test cases to fail because they compare the results directly.

Add unit test in ExchangeSuite. With this patch (and `spark.sql.execution.sortBeforeRepartition` set to true), the following query returns 1000000:

```
import scala.sys.process._
import org.apache.spark.TaskContext

spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()

res7: Long = 1000000
```

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

## How was this patch tested?

Ran all SBT unit tests for org.apache.spark.sql.*. Ran pyspark tests for module pyspark-sql.

Closes #22079 from bersprockets/SPARK-23207.

Lead-authored-by: Xingbo Jiang <xingbo.jiang@databricks.com>
Co-authored-by: Bruce Robbins <bersprockets@gmail.com>
Co-authored-by: Zheng RuiFeng <ruifengz@foxmail.com>
Signed-off-by: Xiao Li <gatorsmile@gmail.com>
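A minimal sketch (assuming Spark 2.2.x with this backport applied) of opting out of the new behavior when the extra local sort is too costly:

```scala
// The local sort before repartition is on by default with this patch; setting the
// flag to "false" restores the previous, faster but order-sensitive behavior.
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
```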
1 parent 124789b commit d7c3aae

20 files changed: +575 -30 lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java

Lines changed: 3 additions & 1 deletion

@@ -32,6 +32,8 @@ public abstract class RecordComparator {
   public abstract int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset);
+      long rightBaseOffset,
+      int rightBaseLength);
 }
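For illustration, a hypothetical comparator sketched in Scala showing how an implementation might consume the new length arguments (this class is not from the patch; `Platform.getByte` is used for raw byte access):

```scala
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.collection.unsafe.sort.RecordComparator

// Hypothetical comparator: with the widened signature the record lengths arrive as
// arguments, so the implementation can order by length first, then byte-by-byte.
class ExampleBinaryComparator extends RecordComparator {
  override def compare(
      leftBaseObject: AnyRef, leftBaseOffset: Long, leftBaseLength: Int,
      rightBaseObject: AnyRef, rightBaseOffset: Long, rightBaseLength: Int): Int = {
    if (leftBaseLength != rightBaseLength) {
      return leftBaseLength - rightBaseLength
    }
    var i = 0
    while (i < leftBaseLength) {
      // Compare unsigned byte values read straight off each record's base object.
      val l = Platform.getByte(leftBaseObject, leftBaseOffset + i) & 0xff
      val r = Platform.getByte(rightBaseObject, rightBaseOffset + i) & 0xff
      if (l != r) return l - r
      i += 1
    }
    0
  }
}
```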

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 4 additions & 3 deletions

@@ -61,12 +61,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
       int uaoSize = UnsafeAlignedOffset.getUaoSize();
       if (prefixComparisonResult == 0) {
         final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
-        // skip length
         final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
+        final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize);
         final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
-        // skip length
         final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
-        return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+        final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize);
+        return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2,
+          baseOffset2, baseLength2);
       } else {
         return prefixComparisonResult;
       }

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java

Lines changed: 2 additions & 2 deletions

@@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger {
           prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
         if (prefixComparisonResult == 0) {
           return recordComparator.compare(
-            left.getBaseObject(), left.getBaseOffset(),
-            right.getBaseObject(), right.getBaseOffset());
+            left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
+            right.getBaseObject(), right.getBaseOffset(), right.getRecordLength());
         } else {
           return prefixComparisonResult;
         }

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 2 additions & 0 deletions

@@ -413,6 +413,8 @@ abstract class RDD[T: ClassTag](
    *
    * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
    * which can avoid performing a shuffle.
+   *
+   * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
    */
   def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope {
     coalesce(numPartitions, shuffle = true)
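The TODO above records that RDD.repartition() has the same hazard. A sketch of the shape of the problem at the RDD level (illustrative only; assumes a live SparkContext `sc`):

```scala
// repartition() assigns rows to output partitions based on their input order, which is
// not deterministic after a shuffle; if an upstream task is retried, its rows can land
// in different partitions than on the first attempt (the scenario tracked by SPARK-23207).
val rdd = sc.parallelize(0L until 1000L * 1000L, 10)
val repartitioned = rdd.repartition(200)
```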

core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java

Lines changed: 10 additions & 0 deletions

@@ -17,6 +17,10 @@

 package org.apache.spark.memory;

+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
 import java.io.IOException;

 public class TestMemoryConsumer extends MemoryConsumer {
@@ -43,6 +47,12 @@ void free(long size) {
     used -= size;
     taskMemoryManager.releaseExecutionMemory(size, this);
   }
+
+  @VisibleForTesting
+  public void freePage(MemoryBlock page) {
+    used -= page.size();
+    taskMemoryManager.freePage(page, this);
+  }
 }

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 3 additions & 1 deletion

@@ -71,8 +71,10 @@ public class UnsafeExternalSorterSuite {
     public int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset) {
+      long rightBaseOffset,
+      int rightBaseLength) {
       return 0;
     }
   };

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java

Lines changed: 6 additions & 2 deletions

@@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
     public int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset) {
+      long rightBaseOffset,
+      int rightBaseLength) {
       return 0;
     }
   };
@@ -164,8 +166,10 @@ public void freeAfterOOM() {
     public int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset) {
+      long rightBaseOffset,
+      int rightBaseLength) {
       return 0;
     }
   };

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 1 addition & 1 deletion

@@ -154,7 +154,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
     val dataArray = Array.tabulate(weights.length) { i =>
       Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
     }
-    spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
+    spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path))
   }

   def load(sc: SparkContext, path: String): GaussianMixtureModel = {

mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala

Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
     val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
       Data(model.selectedFeatures(i))
     }
-    spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
+    spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path))
   }

   def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
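Both MLlib save paths above replace `repartition(1)` with an RDD created with a single partition, so no shuffle (and hence no order-sensitive repartition) is involved. A minimal sketch of the pattern; `spark`, `sc`, the `Data` case class, and the output path are placeholders:

```scala
// Build the one-partition RDD up front instead of shuffling a DataFrame down to a
// single partition with repartition(1).
case class Data(value: Double)

val dataArray = Array(Data(0.1), Data(0.2), Data(0.3))
spark.createDataFrame(sc.makeRDD(dataArray, numSlices = 1))
  .write.parquet("/tmp/example-output")
```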

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 2 additions & 1 deletion

@@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val oldModel = new OldWord2VecModel(word2VecMap)
     val instance = new Word2VecModel("myWord2VecModel", oldModel)
     val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+    assert(newInstance.getVectors.collect().sortBy(_.getString(0)) ===
+      instance.getVectors.collect().sortBy(_.getString(0)))
   }

   test("Word2Vec works with input that is non-nullable (NGram)") {
