18
18
import sys
19
19
import random
20
20
21
+
21
22
class RDDSampler (object ):
22
23
def __init__ (self , withReplacement , fraction , seed = None ):
23
24
try :
24
25
import numpy
25
26
self ._use_numpy = True
26
27
except ImportError :
27
- print >> sys .stderr , "NumPy does not appear to be installed. Falling back to default random generator for sampling."
28
+ print >> sys .stderr , (
29
+ "NumPy does not appear to be installed. "
30
+ "Falling back to default random generator for sampling." )
28
31
self ._use_numpy = False
29
32
30
33
self ._seed = seed if seed is not None else random .randint (0 , sys .maxint )
@@ -61,7 +64,7 @@ def getUniformSample(self, split):
61
64
def getPoissonSample (self , split , mean ):
62
65
if not self ._rand_initialized or split != self ._split :
63
66
self .initRandomGenerator (split )
64
-
67
+
65
68
if self ._use_numpy :
66
69
return self ._random .poisson (mean )
67
70
else :
@@ -80,30 +83,27 @@ def getPoissonSample(self, split, mean):
80
83
num_arrivals += 1
81
84
82
85
return (num_arrivals - 1 )
83
-
86
+
84
87
def shuffle (self , vals ):
85
88
if self ._random is None :
86
89
self .initRandomGenerator (0 ) # this should only ever called on the master so
87
90
# the split does not matter
88
-
91
+
89
92
if self ._use_numpy :
90
93
self ._random .shuffle (vals )
91
94
else :
92
95
self ._random .shuffle (vals , self ._random .random )
93
96
94
97
def func (self , split , iterator ):
95
- if self ._withReplacement :
98
+ if self ._withReplacement :
96
99
for obj in iterator :
97
- # For large datasets, the expected number of occurrences of each element in a sample with
98
- # replacement is Poisson(frac). We use that to get a count for each element.
99
- count = self .getPoissonSample (split , mean = self ._fraction )
100
+ # For large datasets, the expected number of occurrences of each element in
101
+ # a sample with replacement is Poisson(frac). We use that to get a count for
102
+ # each element.
103
+ count = self .getPoissonSample (split , mean = self ._fraction )
100
104
for _ in range (0 , count ):
101
105
yield obj
102
106
else :
103
107
for obj in iterator :
104
108
if self .getUniformSample (split ) <= self ._fraction :
105
109
yield obj
106
-
107
-
108
-
109
-
0 commit comments