-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SPARK-1438 RDD.sample() make seed param optional #477
Changes from 2 commits
0c247db
69619c6
8d05b1a
b9ebfe2
07bb06e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,7 +465,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { | |
|
||
test("takeSample") { | ||
val data = sc.parallelize(1 to 100, 2) | ||
|
||
|
||
for (num <- List(5,20,100)) { | ||
val sample = data.takeSample(withReplacement=false, num=num) | ||
assert(sample.size === num) // Got exactly num elements | ||
assert(sample.toSet.size === num) // Elements are distinct | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indenting seems off here; there seem to be some tabs. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will take care of the indent. |
||
for (seed <- 1 to 5) { | ||
val sample = data.takeSample(withReplacement=false, 20, seed) | ||
assert(sample.size === 20) // Got exactly 20 elements | ||
|
@@ -483,6 +489,19 @@ class RDDSuite extends FunSuite with SharedSparkContext { | |
assert(sample.size === 20) // Got exactly 20 elements | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
{ | ||
val sample = data.takeSample(withReplacement=true, num=20) | ||
assert(sample.size === 20) // Got exactly 100 elements | ||
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
{ | ||
val sample = data.takeSample(withReplacement=true, num=100) | ||
assert(sample.size === 100) // Got exactly 100 elements | ||
// Chance of getting all distinct elements is astronomically low, so test we got < 100 | ||
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
for (seed <- 1 to 5) { | ||
val sample = data.takeSample(withReplacement=true, 100, seed) | ||
assert(sample.size === 100) // Got exactly 100 elements | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
from threading import Thread | ||
import warnings | ||
import heapq | ||
import random | ||
|
||
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ | ||
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long | ||
|
@@ -332,7 +333,7 @@ def distinct(self): | |
.reduceByKey(lambda x, _: x) \ | ||
.map(lambda (x, _): x) | ||
|
||
def sample(self, withReplacement, fraction, seed): | ||
def sample(self, withReplacement, fraction, seed=None): | ||
""" | ||
Return a sampled subset of this RDD (relies on numpy and falls back | ||
on default random generator if numpy is unavailable). | ||
|
@@ -344,7 +345,7 @@ def sample(self, withReplacement, fraction, seed): | |
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) | ||
|
||
# this is ported from scala/spark/RDD.scala | ||
def takeSample(self, withReplacement, num, seed): | ||
def takeSample(self, withReplacement, num, seed=None): | ||
""" | ||
Return a fixed-size sampled subset of this RDD (currently requires numpy). | ||
|
||
|
@@ -381,13 +382,11 @@ def takeSample(self, withReplacement, num, seed): | |
# If the first sample didn't turn out large enough, keep trying to take samples; | ||
# this shouldn't happen often because we use a big multiplier for their initial size. | ||
# See: scala/spark/RDD.scala | ||
random.seed(seed) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this the global random object? Library code should not be setting the seed and then calling randint. Is there no equivalent of java.util.Random that you can create and use here? |
||
while len(samples) < total: | ||
if seed > sys.maxint - 2: | ||
seed = -1 | ||
seed += 1 | ||
samples = self.sample(withReplacement, fraction, seed).collect() | ||
samples = self.sample(withReplacement, fraction, random.randint(0,sys.maxint)).collect() | ||
|
||
sampler = RDDSampler(withReplacement, fraction, seed+1) | ||
sampler = RDDSampler(withReplacement, fraction, random.randint(0,sys.maxint)) | ||
sampler.shuffle(samples) | ||
return samples[0:total] | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ | |
import random | ||
|
||
class RDDSampler(object): | ||
def __init__(self, withReplacement, fraction, seed): | ||
def __init__(self, withReplacement, fraction, seed=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't you need to do something later to deal with seed being None? Does random.seed(None) do the right thing? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @mateiz both the numpy random and Python language random functions should handle None fine. import numpy There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @smartnut007, @mateiz, if we should not use nanoTime as a seed, we should not pass None to random.seed(None). random.seed(None) will use time.time(). In reply to your very first comment, I think consistency is important. But it's just my opinion. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, in that case just use Python's built-in random, or create a Random object if there isn't a global one you can call. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can simply add There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @advancedxy @mateiz from binascii import hexlify as _hexlify I generally think it's better to leave these things to the language implementors. But, if we need the code to look similar, then I would follow the last suggestion in RDDSampler. Can you guys let me know which one? |
||
try: | ||
import numpy | ||
self._use_numpy = True | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put spaces after the commas here