Skip to content

Commit 5553198

Browse files
brkyvz authored and rxin committed
[SPARK-7156][SQL] Addressed follow up comments for randomSplit
Small fixes regarding comments in PR #5761.

cc rxin

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5795 from brkyvz/split-followup and squashes the following commits:

369c522 [Burak Yavuz] changed wording a little
1ea456f [Burak Yavuz] Addressed follow up comments
1 parent 7143f6e commit 5553198

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -437,6 +437,10 @@ def sample(self, withReplacement, fraction, seed=None):
437437
def randomSplit(self, weights, seed=None):
438438
"""Randomly splits this :class:`DataFrame` with the provided weights.
439439
440+
:param weights: list of doubles as weights with which to split the DataFrame. Weights will
441+
be normalized if they don't sum up to 1.0.
442+
:param seed: The seed for sampling.
443+
440444
>>> splits = df4.randomSplit([1.0, 2.0], 24)
441445
>>> splits[0].count()
442446
1
@@ -445,7 +449,8 @@ def randomSplit(self, weights, seed=None):
445449
3
446450
"""
447451
for w in weights:
448-
assert w >= 0.0, "Negative weight value: %s" % w
452+
if w < 0.0:
453+
raise ValueError("Weights must be positive. Found weight value: %s" % w)
449454
seed = seed if seed is not None else random.randint(0, sys.maxsize)
450455
rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
451456
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -752,7 +752,7 @@ class DataFrame private[sql](
752752
* @param seed Seed for sampling.
753753
* @group dfops
754754
*/
755-
def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
755+
private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
756756
randomSplit(weights.toArray, seed)
757757
}
758758

0 commit comments

Comments
 (0)