Commit ddf6dd8

sameeragarwal authored and cloud-fan committed
[SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit
## What changes were proposed in this pull request?

In `randomSplit`, it is possible that the underlying dataset doesn't guarantee the ordering of rows in its constituent partitions each time a split is materialized, which could result in overlapping splits. To prevent this, as part of SPARK-12662, we explicitly sort each input partition to make the ordering deterministic. Given that `MapType`s cannot be sorted, this patch explicitly prunes them out from the sort order. Additionally, if the resulting sort order is empty, this patch then materializes the dataset to guarantee determinism.

## How was this patch tested?

Extended `randomSplit on reordered partitions` in `DataFrameStatSuite` to also test for dataframes with mapTypes and nested mapTypes.

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes #17751 from sameeragarwal/randomsplit2.

(cherry picked from commit 31345fd)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 84be4c8 commit ddf6dd8
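For context, a minimal sketch of the scenario this patch addresses. The session setup, variable names, and column names below are illustrative, not from the commit; the point is that a map nested inside an array slips past a check for top-level `MapType` columns only, while with this patch any unorderable column is simply pruned from the sort order:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("randomSplitMaps").getOrCreate()
import spark.implicits._

// An array-of-maps column is not itself a MapType, but it still cannot be
// sorted; pruning by orderability (rather than by top-level MapType) keeps
// randomSplit working on such schemas.
val df = spark.sparkContext.parallelize(1 to 600, 2)
  .map(i => (i, Array(Map(i -> i.toString))))
  .toDF("int", "arrayOfMaps")

val Array(small, large) = df.randomSplit(Array[Double](2, 3), seed = 1)
// The splits partition df: disjoint, and together they cover every row.
assert(small.count() + large.count() == 600)
```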

2 files changed: +41 additions, −20 deletions

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 13 additions & 5 deletions
@@ -1542,15 +1542,23 @@ class Dataset[T] private[sql](
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
-    // ordering deterministic.
-    // MapType cannot be sorted.
-    val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
-      .map(SortOrder(_, Ascending)), global = false, logicalPlan)
+    // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+    // from the sort order.
+    val sortOrder = logicalPlan.output
+      .filter(attr => RowOrdering.isOrderable(attr.dataType))
+      .map(SortOrder(_, Ascending))
+    val plan = if (sortOrder.nonEmpty) {
+      Sort(sortOrder, global = false, logicalPlan)
+    } else {
+      // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
+      cache()
+      logicalPlan
+    }
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
       new Dataset[T](
-        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
     }.toArray
   }
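Two hedged sketches of the new behavior follow. First, how `RowOrdering.isOrderable` (the check the hunk above introduces) classifies the types in play; this leans on Spark catalyst internals, and the import path is an assumption based on Spark 2.x:

```scala
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

// Maps are unorderable, and a map nested inside an array (or struct) makes
// the enclosing type unorderable too, unlike the old top-level MapType check.
assert(RowOrdering.isOrderable(IntegerType))
assert(!RowOrdering.isOrderable(MapType(IntegerType, StringType)))
assert(!RowOrdering.isOrderable(ArrayType(MapType(IntegerType, StringType))))
assert(RowOrdering.isOrderable(new StructType().add("a", IntegerType)))
```

Second, the fallback branch: when no column at all is orderable, the sort order comes out empty and the dataset is cached instead. A sketch under the same session assumptions as the earlier example; `onlyMaps` is a made-up name:

```scala
// Every column is a map, so the sort order is empty and the dataset is
// cached; caching pins partition contents, so repeated materializations of
// each split see the same underlying rows.
val onlyMaps = Seq(Map(1 -> "a"), Map(2 -> "b"), Map(3 -> "c")).toDF("m")
val Array(left, right) = onlyMaps.randomSplit(Array[Double](1, 1), seed = 7)
assert(left.collect().toSet.intersect(right.collect().toSet).isEmpty)
```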

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 28 additions & 15 deletions
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }
 
   test("randomSplit on reordered partitions") {
-    // This test ensures that randomSplit does not create overlapping splits even when the
-    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
-    // rows in each partition.
-    val data =
-      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
-    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
 
-    assert(splits.length == 2, "wrong number of splits")
+    def testNonOverlappingSplits(data: DataFrame): Unit = {
+      val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+      assert(splits.length == 2, "wrong number of splits")
+
+      // Verify that the splits span the entire dataset
+      assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
 
-    // Verify that the splits span the entire dataset
-    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+      // Verify that the splits don't overlap
+      assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
 
-    // Verify that the splits don't overlap
-    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+      // Verify that the results are deterministic across multiple runs
+      val firstRun = splits.toSeq.map(_.collect().toSeq)
+      val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+      assert(firstRun == secondRun)
+    }
 
-    // Verify that the results are deterministic across multiple runs
-    val firstRun = splits.toSeq.map(_.collect().toSeq)
-    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
-    assert(firstRun == secondRun)
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+    val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Map(i -> i.toString)))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+    val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Array(Map(i -> i.toString))))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+    testNonOverlappingSplits(dataWithInts)
+    testNonOverlappingSplits(dataWithMaps)
+    testNonOverlappingSplits(dataWithArrayOfMaps)
   }
 
   test("pearson correlation") {
