Skip to content

Commit

Permalink
SPARK-1438 RDD.sample() make seed param optional
Browse files Browse the repository at this point in the history
copying form previous pull request #462

Its probably better to let the underlying language implementation take care of the default . This was easier to do with python as the default value for seed in random and numpy random is None.

In Scala/Java side it might mean propagating an Option or null(oh no!) down the chain until where the Random is constructed. But, looks like the convention in some other methods was to use System.nanoTime. So, followed that convention.

Conflict with overloaded method in sql.SchemaRDD.sample which also defines default params.
sample(fraction, withReplacement=false, seed=math.random)
Scala does not allow more than one overloaded to have default params. I believe the author intended to override the RDD.sample method and not overload it. So, changed it.

If backward compatible is important, 3 new method can be introduced (without default params) like this
sample(fraction)
sample(fraction, withReplacement)
sample(fraction, withReplacement, seed)

Added some tests for the scala RDD takeSample method.

Author: Arun Ramakrishnan <smartnut007@gmail.com>

This patch had conflicts when merged, resolved by
Committer: Matei Zaharia <matei@databricks.com>

Closes #477 from smartnut007/master and squashes the following commits:

07bb06e [Arun Ramakrishnan] SPARK-1438 fixing more space formatting issues
b9ebfe2 [Arun Ramakrishnan] SPARK-1438 removing redundant import of random in python rddsampler
8d05b1a [Arun Ramakrishnan] SPARK-1438 RDD . Replace System.nanoTime with a Random generated number. python: use a separate instance of Random instead of seeding language api global Random instance.
69619c6 [Arun Ramakrishnan] SPARK-1438 fix spacing issue
0c247db [Arun Ramakrishnan] SPARK-1438 RDD language apis to support optional seed in RDD methods sample/takeSample

(cherry picked from commit 35e3d19)
Signed-off-by: Matei Zaharia <matei@databricks.com>
  • Loading branch information
arun-rama authored and mateiz committed Apr 25, 2014
1 parent a1f8779 commit 521d435
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
import org.apache.spark.util.Utils

class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] {

Expand Down Expand Up @@ -133,7 +134,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
sample(withReplacement, fraction, Utils.random.nextLong)

/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
fromRDD(srdd.sample(withReplacement, fraction, seed))

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
(implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V])
Expand Down Expand Up @@ -119,7 +120,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
sample(withReplacement, fraction, Utils.random.nextLong)

/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))

/**
Expand Down
9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark._
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
extends JavaRDDLike[T, JavaRDD[T]] {
Expand Down Expand Up @@ -98,7 +99,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
sample(withReplacement, fraction, Utils.random.nextLong)

/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
Expand Down Expand Up @@ -394,7 +395,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}

def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
def takeSample(withReplacement: Boolean, num: Int): JList[T] =
takeSample(withReplacement, num, Utils.random.nextLong)

def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
new java.util.ArrayList(arr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.Utils

private[spark]
class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
Expand All @@ -38,14 +39,14 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
*
* @param prev RDD to be sampled
* @param sampler a random sampler
* @param seed random seed, default to System.nanoTime
* @param seed random seed
* @tparam T input RDD item type
* @tparam U sampled RDD item type
*/
private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
prev: RDD[T],
sampler: RandomSampler[T, U],
@transient seed: Long = System.nanoTime)
@transient seed: Long = Utils.random.nextLong)
extends RDD[U](prev) {

override def getPartitions: Array[Partition] = {
Expand Down
11 changes: 7 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
def sample(withReplacement: Boolean,
fraction: Double,
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
Expand All @@ -354,19 +356,20 @@ abstract class RDD[T: ClassTag](
* Randomly splits this RDD with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1
* @param seed random seed, default to System.nanoTime
* @param seed random seed
*
* @return split RDDs in an array
*/
def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = {
def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[RDD[T]] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
}.toArray
}

def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
{
var fraction = 0.0
var total = 0
val multiplier = 3.0
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
private[spark] object Utils extends Logging {

val osName = System.getProperty("os.name")

val random = new Random()

/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
Expand Down
21 changes: 20 additions & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {

test("takeSample") {
val data = sc.parallelize(1 to 100, 2)


for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
Expand All @@ -481,6 +487,19 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
}
{
val sample = data.takeSample(withReplacement=true, num=100)
assert(sample.size === 100) // Got exactly 100 elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 100, seed)
assert(sample.size === 100) // Got exactly 100 elements
Expand Down
13 changes: 6 additions & 7 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from threading import Thread
import warnings
import heapq
from random import Random

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
Expand Down Expand Up @@ -332,7 +333,7 @@ def distinct(self):
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)

def sample(self, withReplacement, fraction, seed):
def sample(self, withReplacement, fraction, seed=None):
"""
Return a sampled subset of this RDD (relies on numpy and falls back
on default random generator if numpy is unavailable).
Expand All @@ -344,7 +345,7 @@ def sample(self, withReplacement, fraction, seed):
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed):
def takeSample(self, withReplacement, num, seed=None):
"""
Return a fixed-size sampled subset of this RDD (currently requires numpy).
Expand Down Expand Up @@ -381,13 +382,11 @@ def takeSample(self, withReplacement, num, seed):
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
rand = Random(seed)
while len(samples) < total:
if seed > sys.maxint - 2:
seed = -1
seed += 1
samples = self.sample(withReplacement, fraction, seed).collect()
samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()

sampler = RDDSampler(withReplacement, fraction, seed+1)
sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
sampler.shuffle(samples)
return samples[0:total]

Expand Down
31 changes: 14 additions & 17 deletions python/pyspark/rddsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import random

class RDDSampler(object):
def __init__(self, withReplacement, fraction, seed):
def __init__(self, withReplacement, fraction, seed=None):
try:
import numpy
self._use_numpy = True
except ImportError:
print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
self._use_numpy = False

self._seed = seed
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
self._fraction = fraction
self._random = None
Expand All @@ -38,17 +38,14 @@ def initRandomGenerator(self, split):
if self._use_numpy:
import numpy
self._random = numpy.random.RandomState(self._seed)
for _ in range(0, split):
# discard the next few values in the sequence to have a
# different seed for the different splits
self._random.randint(sys.maxint)
else:
import random
random.seed(self._seed)
for _ in range(0, split):
# discard the next few values in the sequence to have a
# different seed for the different splits
random.randint(0, sys.maxint)
self._random = random.Random(self._seed)

for _ in range(0, split):
# discard the next few values in the sequence to have a
# different seed for the different splits
self._random.randint(0, sys.maxint)

self._split = split
self._rand_initialized = True

Expand All @@ -59,7 +56,7 @@ def getUniformSample(self, split):
if self._use_numpy:
return self._random.random_sample()
else:
return random.uniform(0.0, 1.0)
return self._random.uniform(0.0, 1.0)

def getPoissonSample(self, split, mean):
if not self._rand_initialized or split != self._split:
Expand All @@ -73,26 +70,26 @@ def getPoissonSample(self, split, mean):
num_arrivals = 1
cur_time = 0.0

cur_time += random.expovariate(mean)
cur_time += self._random.expovariate(mean)

if cur_time > 1.0:
return 0

while(cur_time <= 1.0):
cur_time += random.expovariate(mean)
cur_time += self._random.expovariate(mean)
num_arrivals += 1

return (num_arrivals - 1)

def shuffle(self, vals):
if self._random == None or split != self._split:
if self._random == None:
self.initRandomGenerator(0) # this should only ever called on the master so
# the split does not matter

if self._use_numpy:
self._random.shuffle(vals)
else:
random.shuffle(vals, self._random)
self._random.shuffle(vals, self._random.random)

def func(self, split, iterator):
if self._withReplacement:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
def references = Set.empty
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {

def output = child.output
Expand Down
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,11 @@ class SchemaRDD(
* @group Query
*/
@Experimental
override
def sample(
fraction: Double,
withReplacement: Boolean = true,
seed: Int = (math.random * 1000).toInt) =
fraction: Double,
seed: Long) =
new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
* :: DeveloperApi ::
*/
@DeveloperApi
case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
extends UnaryNode {

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
extends UnaryNode
{
override def output = child.output

// TODO: How to pick seed?
Expand Down

0 comments on commit 521d435

Please sign in to comment.