Commit d56cacb

impl with RandomRDD
1 parent 92d6f1c commit d56cacb

File tree: 6 files changed, +552 -119 lines

mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala

Lines changed: 49 additions & 19 deletions
@@ -22,54 +22,84 @@ import scala.util.Random
 import cern.jet.random.Poisson
 import cern.jet.random.engine.DRand
 
-import org.apache.spark.util.random.Pseudorandom
+import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
 
 /**
- * Trait for random number generators that generate i.i.d values from a distribution
+ * Trait for random number generators that generate i.i.d values from a distribution.
  */
-trait DistributionGenerator extends Pseudorandom with Cloneable {
+trait DistributionGenerator extends Pseudorandom with Serializable {
 
   /**
-   * @return An i.i.d sample as a Double from an underlying distribution
+   * @return An i.i.d sample as a Double from an underlying distribution.
    */
   def nextValue(): Double
 
-  override def clone(): DistributionGenerator = super.clone().asInstanceOf[DistributionGenerator]
+  /**
+   * @return A copy of the DistributionGenerator with a new instance of the rng object used in the
+   *         class when applicable. Each partition has a unique seed and therefore requires its
+   *         own instance of the DistributionGenerator.
+   */
+  def newInstance(): DistributionGenerator
 }
 
-class NormalGenerator(val mean: Double = 0.0, val stddev: Double = 1.0)
-  extends DistributionGenerator {
+/**
+ * Generates i.i.d. samples from U[0.0, 1.0]
+ */
+class UniformGenerator() extends DistributionGenerator {
+  // XORShiftRandom for better performance. Thread safety isn't necessary here.
+  private val random = new XORShiftRandom()
 
-  require(stddev >= 0.0, "Standard deviation cannot be negative.")
+  /**
+   * @return An i.i.d sample as a Double from U[0.0, 1.0].
+   */
+  override def nextValue(): Double = {
+    random.nextDouble()
+  }
 
-  private val random = new Random()
-  private val standard = mean == 0.0 && stddev == 1.0
+  /** Set random seed. */
+  override def setSeed(seed: Long) = random.setSeed(seed)
+
+  override def newInstance(): UniformGenerator = new UniformGenerator()
+}
+
+/**
+ * Generates i.i.d. samples from the Standard Normal Distribution.
+ */
+class StandardNormalGenerator() extends DistributionGenerator {
+  // XORShiftRandom for better performance. Thread safety isn't necessary here.
+  private val random = new XORShiftRandom()
 
   /**
-   * @return An i.i.d sample as a Double from the Normal distribution
+   * @return An i.i.d sample as a Double from the Normal distribution.
    */
   override def nextValue(): Double = {
-    if (standard) {
       random.nextGaussian()
-    } else {
-      mean + stddev * random.nextGaussian()
-    }
   }
 
   /** Set random seed. */
   override def setSeed(seed: Long) = random.setSeed(seed)
+
+  override def newInstance(): StandardNormalGenerator = new StandardNormalGenerator()
 }
 
-class PoissonGenerator(val lambda: Double = 0.0) extends DistributionGenerator {
+/**
+ * Generates i.i.d. samples from the Poisson distribution with the given mean.
+ *
+ * @param mean mean for the Poisson distribution.
+ */
+class PoissonGenerator(val mean: Double) extends DistributionGenerator {
+
+  private var rng = new Poisson(mean, new DRand)
 
-  private var rng = new Poisson(lambda, new DRand)
   /**
-   * @return An i.i.d sample as a Double from the Poisson distribution
+   * @return An i.i.d sample as a Double from the Poisson distribution.
    */
   override def nextValue(): Double = rng.nextDouble()
 
   /** Set random seed. */
   override def setSeed(seed: Long) {
-    rng = new Poisson(lambda, new DRand(seed.toInt))
+    rng = new Poisson(mean, new DRand(seed.toInt))
   }
+
+  override def newInstance(): PoissonGenerator = new PoissonGenerator(mean)
 }
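
The newInstance() method replaces clone() so that each partition can hold its own generator copy, seeded independently of the shared instance. A minimal sketch of how that contract is meant to be used follows; the helper samplesForPartition and its base-seed-plus-partition-index scheme are illustrative assumptions, not code from this commit.

import org.apache.spark.mllib.random.{DistributionGenerator, PoissonGenerator}

object NewInstanceSketch {
  // Hypothetical helper: draw `count` samples for one partition. The shared
  // `generator` is never mutated; a fresh copy is created and seeded per
  // partition, so streams are independent across partitions yet reproducible.
  def samplesForPartition(
      generator: DistributionGenerator,
      partitionIndex: Int,
      baseSeed: Long,
      count: Int): Array[Double] = {
    val local = generator.newInstance()
    local.setSeed(baseSeed + partitionIndex) // illustrative seeding scheme only
    Array.fill(count)(local.nextValue())
  }

  def main(args: Array[String]): Unit = {
    val poisson = new PoissonGenerator(4.0)
    val first = samplesForPartition(poisson, partitionIndex = 0, baseSeed = 42L, count = 5)
    val second = samplesForPartition(poisson, partitionIndex = 1, baseSeed = 42L, count = 5)
    println(first.mkString(", "))
    println(second.mkString(", "))
  }
}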

mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala (new file)

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+// TODO add Scaladocs once API fully approved
+object RandomRDDGenerators {
+
+  def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = {
+    val uniform = new UniformGenerator()
+    randomRDD(sc, size, numPartitions, uniform, seed)
+  }
+
+  def uniformRDD(sc: SparkContext, size: Long, seed: Long): RDD[Double] = {
+    uniformRDD(sc, size, sc.defaultParallelism, seed)
+  }
+
+  def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = {
+    uniformRDD(sc, size, numPartitions, Utils.random.nextLong)
+  }
+
+  def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = {
+    uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong)
+  }
+
+  def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = {
+    val normal = new StandardNormalGenerator()
+    randomRDD(sc, size, numPartitions, normal, seed)
+  }
+
+  def normalRDD(sc: SparkContext, size: Long, seed: Long): RDD[Double] = {
+    normalRDD(sc, size, sc.defaultParallelism, seed)
+  }
+
+  def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = {
+    normalRDD(sc, size, numPartitions, Utils.random.nextLong)
+  }
+
+  def normalRDD(sc: SparkContext, size: Long): RDD[Double] = {
+    normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong)
+  }
+
+  def poissonRDD(sc: SparkContext,
+      size: Long,
+      numPartitions: Int,
+      mean: Double,
+      seed: Long): RDD[Double] = {
+    val poisson = new PoissonGenerator(mean)
+    randomRDD(sc, size, numPartitions, poisson, seed)
+  }
+
+  def poissonRDD(sc: SparkContext, size: Long, mean: Double, seed: Long): RDD[Double] = {
+    poissonRDD(sc, size, sc.defaultParallelism, mean, seed)
+  }
+
+  def poissonRDD(sc: SparkContext, size: Long, numPartitions: Int, mean: Double): RDD[Double] = {
+    poissonRDD(sc, size, numPartitions, mean, Utils.random.nextLong)
+  }
+
+  def poissonRDD(sc: SparkContext, size: Long, mean: Double): RDD[Double] = {
+    poissonRDD(sc, size, sc.defaultParallelism, mean, Utils.random.nextLong)
+  }
+
+  def randomRDD(sc: SparkContext,
+      size: Long,
+      numPartitions: Int,
+      distribution: DistributionGenerator,
+      seed: Long): RDD[Double] = {
+    new RandomRDD(sc, size, numPartitions, distribution, seed)
+  }
+
+  def randomRDD(sc: SparkContext,
+      size: Long,
+      distribution: DistributionGenerator,
+      seed: Long): RDD[Double] = {
+    randomRDD(sc, size, sc.defaultParallelism, distribution, seed)
+  }
+
+  def randomRDD(sc: SparkContext,
+      size: Long,
+      numPartitions: Int,
+      distribution: DistributionGenerator): RDD[Double] = {
+    randomRDD(sc, size, numPartitions, distribution, Utils.random.nextLong)
+  }
+
+  def randomRDD(sc: SparkContext,
+      size: Long,
+      distribution: DistributionGenerator): RDD[Double] = {
+    randomRDD(sc, size, sc.defaultParallelism, distribution, Utils.random.nextLong)
+  }
+
+  // TODO Generator RDD[Vector] from multivariate distribution
+
+  def uniformVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int,
+      seed: Long): RDD[Vector] = {
+    val uniform = new UniformGenerator()
+    randomVectorRDD(sc, numRows, numColumns, numPartitions, uniform, seed)
+  }
+
+  def uniformVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      seed: Long): RDD[Vector] = {
+    uniformVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, seed)
+  }
+
+  def uniformVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int): RDD[Vector] = {
+    uniformVectorRDD(sc, numRows, numColumns, numPartitions, Utils.random.nextLong)
+  }
+
+  def uniformVectorRDD(sc: SparkContext, numRows: Long, numColumns: Int): RDD[Vector] = {
+    uniformVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, Utils.random.nextLong)
+  }
+
+  def normalVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int,
+      seed: Long): RDD[Vector] = {
+    val uniform = new StandardNormalGenerator()
+    randomVectorRDD(sc, numRows, numColumns, numPartitions, uniform, seed)
+  }
+
+  def normalVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      seed: Long): RDD[Vector] = {
+    normalVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, seed)
+  }
+
+  def normalVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int): RDD[Vector] = {
+    normalVectorRDD(sc, numRows, numColumns, numPartitions, Utils.random.nextLong)
+  }
+
+  def normalVectorRDD(sc: SparkContext, numRows: Long, numColumns: Int): RDD[Vector] = {
+    normalVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, Utils.random.nextLong)
+  }
+
+  def poissonVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int,
+      mean: Double,
+      seed: Long): RDD[Vector] = {
+    val poisson = new PoissonGenerator(mean)
+    randomVectorRDD(sc, numRows, numColumns, numPartitions, poisson, seed)
+  }
+
+  def poissonVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      mean: Double,
+      seed: Long): RDD[Vector] = {
+    poissonVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, mean, seed)
+  }
+
+  def poissonVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int,
+      mean: Double): RDD[Vector] = {
+    poissonVectorRDD(sc, numRows, numColumns, numPartitions, mean, Utils.random.nextLong)
+  }
+
+  def poissonVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      mean: Double): RDD[Vector] = {
+    val poisson = new PoissonGenerator(mean)
+    randomVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, poisson, Utils.random.nextLong)
+  }
+
+  def randomVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int,
+      rng: DistributionGenerator,
+      seed: Long): RDD[Vector] = {
+    new RandomVectorRDD(sc, numRows, numColumns, numPartitions, rng, seed)
+  }
+
+  def randomVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      rng: DistributionGenerator,
+      seed: Long): RDD[Vector] = {
+    randomVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, rng, seed)
+  }
+
+  def randomVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      numPartitions: Int,
+      rng: DistributionGenerator): RDD[Vector] = {
+    randomVectorRDD(sc, numRows, numColumns, numPartitions, rng, Utils.random.nextLong)
+  }
+
+  def randomVectorRDD(sc: SparkContext,
+      numRows: Long,
+      numColumns: Int,
+      rng: DistributionGenerator): RDD[Vector] = {
+    randomVectorRDD(sc, numRows, numColumns, sc.defaultParallelism, rng, Utils.random.nextLong)
+  }
+}
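
Since the Scaladocs for this object are still marked TODO, here is a short usage sketch of the factory methods defined above. It assumes only the signatures shown in this commit; the SparkContext setup and the chosen sizes, seeds, and means are illustrative.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.random.RandomRDDGenerators

object RandomRDDGeneratorsSketch {
  def main(args: Array[String]): Unit = {
    // Local setup for illustration only.
    val sc = new SparkContext(
      new SparkConf().setAppName("random-rdd-sketch").setMaster("local[2]"))

    // 1,000,000 i.i.d. samples from U[0.0, 1.0] in 4 partitions, fixed seed for reproducibility.
    val uniform = RandomRDDGenerators.uniformRDD(sc, 1000000L, 4, 42L)

    // 100,000 rows of 10-column vectors drawn from the standard normal distribution,
    // default parallelism and a random seed.
    val normalVectors = RandomRDDGenerators.normalVectorRDD(sc, 100000L, 10)

    // Poisson(mean = 2.0) samples, default parallelism and a random seed.
    val poisson = RandomRDDGenerators.poissonRDD(sc, 1000000L, 2.0)

    println(s"uniform: ${uniform.count()} values, " +
      s"normal vectors: ${normalVectors.count()} rows, " +
      s"poisson: ${poisson.count()} values")

    sc.stop()
  }
}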
