-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SPARK-1438 RDD.sample() make seed param optional #477
Changes from 4 commits
0c247db
69619c6
8d05b1a
b9ebfe2
07bb06e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,7 +465,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { | |
|
||
test("takeSample") { | ||
val data = sc.parallelize(1 to 100, 2) | ||
|
||
|
||
for (num <- List(5,20,100)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put spaces after the commas here |
||
val sample = data.takeSample(withReplacement=false, num=num) | ||
assert(sample.size === num) // Got exactly num elements | ||
assert(sample.toSet.size === num) // Elements are distinct | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
for (seed <- 1 to 5) { | ||
val sample = data.takeSample(withReplacement=false, 20, seed) | ||
assert(sample.size === 20) // Got exactly 20 elements | ||
|
@@ -483,6 +489,19 @@ class RDDSuite extends FunSuite with SharedSparkContext { | |
assert(sample.size === 20) // Got exactly 20 elements | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
{ | ||
val sample = data.takeSample(withReplacement=true, num=20) | ||
assert(sample.size === 20) // Got exactly 100 elements | ||
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
{ | ||
val sample = data.takeSample(withReplacement=true, num=100) | ||
assert(sample.size === 100) // Got exactly 100 elements | ||
// Chance of getting all distinct elements is astronomically low, so test we got < 100 | ||
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") | ||
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
} | ||
for (seed <- 1 to 5) { | ||
val sample = data.takeSample(withReplacement=true, 100, seed) | ||
assert(sample.size === 100) // Got exactly 100 elements | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
from threading import Thread | ||
import warnings | ||
import heapq | ||
from random import Random | ||
|
||
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ | ||
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long | ||
|
@@ -332,7 +333,7 @@ def distinct(self): | |
.reduceByKey(lambda x, _: x) \ | ||
.map(lambda (x, _): x) | ||
|
||
def sample(self, withReplacement, fraction, seed): | ||
def sample(self, withReplacement, fraction, seed=None): | ||
""" | ||
Return a sampled subset of this RDD (relies on numpy and falls back | ||
on default random generator if numpy is unavailable). | ||
|
@@ -344,7 +345,7 @@ def sample(self, withReplacement, fraction, seed): | |
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) | ||
|
||
# this is ported from scala/spark/RDD.scala | ||
def takeSample(self, withReplacement, num, seed): | ||
def takeSample(self, withReplacement, num, seed=None): | ||
""" | ||
Return a fixed-size sampled subset of this RDD (currently requires numpy). | ||
|
||
|
@@ -381,13 +382,11 @@ def takeSample(self, withReplacement, num, seed): | |
# If the first sample didn't turn out large enough, keep trying to take samples; | ||
# this shouldn't happen often because we use a big multiplier for their initial size. | ||
# See: scala/spark/RDD.scala | ||
rand = Random(seed) | ||
while len(samples) < total: | ||
if seed > sys.maxint - 2: | ||
seed = -1 | ||
seed += 1 | ||
samples = self.sample(withReplacement, fraction, seed).collect() | ||
samples = self.sample(withReplacement, fraction, rand.randint(0,sys.maxint)).collect() | ||
|
||
sampler = RDDSampler(withReplacement, fraction, seed+1) | ||
sampler = RDDSampler(withReplacement, fraction, rand.randint(0,sys.maxint)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put spaces after the comma here and in other instances of |
||
sampler.shuffle(samples) | ||
return samples[0:total] | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,15 +19,15 @@ | |
import random | ||
|
||
class RDDSampler(object): | ||
def __init__(self, withReplacement, fraction, seed): | ||
def __init__(self, withReplacement, fraction, seed=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't you need to do something later to deal with seed being None? Does random.seed(None) do the right thing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mateiz both the numpy random and python language random functions should handle None fine. import numpy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @smartnut007, @mateiz, if we should not use nanoTime as a seed, we should not pass none to random.seed(None). random.seed(None) will use time.time(). As reply for your very first comment, I think consistency is important. But it's just my opinion, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, in that case just use Python's built-in random, or create a Random object if there isn't a global one you can call. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can simply add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @advancedxy @mateiz from binascii import hexlify as _hexlify I generally think its better to leave these things to the language implementors. But, if we need the code to look similar, then I would follow the last suggestion in RDDSampler. Can you guys let me know which one ? |
||
try: | ||
import numpy | ||
self._use_numpy = True | ||
except ImportError: | ||
print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling." | ||
self._use_numpy = False | ||
|
||
self._seed = seed | ||
self._seed = seed if seed is not None else random.randint(0,sys.maxint) | ||
self._withReplacement = withReplacement | ||
self._fraction = fraction | ||
self._random = None | ||
|
@@ -38,17 +38,14 @@ def initRandomGenerator(self, split): | |
if self._use_numpy: | ||
import numpy | ||
self._random = numpy.random.RandomState(self._seed) | ||
for _ in range(0, split): | ||
# discard the next few values in the sequence to have a | ||
# different seed for the different splits | ||
self._random.randint(sys.maxint) | ||
else: | ||
import random | ||
random.seed(self._seed) | ||
for _ in range(0, split): | ||
# discard the next few values in the sequence to have a | ||
# different seed for the different splits | ||
random.randint(0, sys.maxint) | ||
self._random = random.Random(self._seed) | ||
|
||
for _ in range(0, split): | ||
# discard the next few values in the sequence to have a | ||
# different seed for the different splits | ||
self._random.randint(0, sys.maxint) | ||
|
||
self._split = split | ||
self._rand_initialized = True | ||
|
||
|
@@ -59,7 +56,7 @@ def getUniformSample(self, split): | |
if self._use_numpy: | ||
return self._random.random_sample() | ||
else: | ||
return random.uniform(0.0, 1.0) | ||
return self._random.uniform(0.0, 1.0) | ||
|
||
def getPoissonSample(self, split, mean): | ||
if not self._rand_initialized or split != self._split: | ||
|
@@ -73,26 +70,26 @@ def getPoissonSample(self, split, mean): | |
num_arrivals = 1 | ||
cur_time = 0.0 | ||
|
||
cur_time += random.expovariate(mean) | ||
cur_time += self._random.expovariate(mean) | ||
|
||
if cur_time > 1.0: | ||
return 0 | ||
|
||
while(cur_time <= 1.0): | ||
cur_time += random.expovariate(mean) | ||
cur_time += self._random.expovariate(mean) | ||
num_arrivals += 1 | ||
|
||
return (num_arrivals - 1) | ||
|
||
def shuffle(self, vals): | ||
if self._random == None or split != self._split: | ||
if self._random == None: | ||
self.initRandomGenerator(0) # this should only ever called on the master so | ||
# the split does not matter | ||
|
||
if self._use_numpy: | ||
self._random.shuffle(vals) | ||
else: | ||
random.shuffle(vals, self._random) | ||
self._random.shuffle(vals, self._random.random) | ||
|
||
def func(self, split, iterator): | ||
if self._withReplacement: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to say what the seed defaults to here since users won't understand it; just say
@param seed random seed
and they can guess that if you don't specify it, we will choose one