SPARK-1438 RDD.sample() make seed param optional #477

Closed · wants to merge 5 commits · showing changes from 2 commits
core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala

@@ -133,7 +133,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD]
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
+  def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
     fromRDD(srdd.sample(withReplacement, fraction, seed))
core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

@@ -119,7 +119,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -98,7 +98,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
     wrapRDD(rdd.sample(withReplacement, fraction, seed))
core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

@@ -394,7 +394,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }

-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+  def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+    takeSample(withReplacement, num, System.nanoTime)
+
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
     new java.util.ArrayList(arr)
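A note on the pattern in the Java API files above: Scala default parameter values are not visible to Java callers, which is why these wrappers add an explicit seedless overload that delegates with System.nanoTime rather than using a default argument. A minimal self-contained sketch of the idea — the Wrapper class is hypothetical, not Spark code:

    // Hypothetical sketch of the overload pattern used by the Java API wrappers.
    // Java callers cannot see Scala default arguments, so the seedless form
    // is written out explicitly and delegates to the seeded form.
    class Wrapper(data: Seq[Int]) {
      def sample(fraction: Double): Seq[Int] =
        sample(fraction, System.nanoTime)  // fresh, time-based seed per call

      def sample(fraction: Double, seed: Long): Seq[Int] = {
        val rng = new scala.util.Random(seed)
        data.filter(_ => rng.nextDouble() < fraction)
      }
    }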
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -321,7 +321,7 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long = System.nanoTime): RDD[T] = {
     require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     if (withReplacement) {
       new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)

@@ -346,7 +346,7 @@
     }.toArray
   }

-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long = System.nanoTime): Array[T] = {
     var fraction = 0.0
     var total = 0
     val multiplier = 3.0
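With the seed now a Long defaulting to System.nanoTime, Scala callers can omit it entirely for non-deterministic sampling, or pass one explicitly for reproducibility. A usage sketch (assumes an existing SparkContext named sc):

    // Usage sketch for the new default-seed signature (sc is assumed).
    val rdd = sc.parallelize(1 to 100)

    // Seed omitted: defaults to System.nanoTime, so each call can differ.
    val s1 = rdd.sample(withReplacement = false, fraction = 0.1)

    // Same explicit seed twice: the two samples are identical.
    val s2 = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)
    val s3 = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)
    assert(s2.collect().toSeq == s3.collect().toSeq)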
21 changes: 20 additions & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -465,7 +465,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {

   test("takeSample") {
     val data = sc.parallelize(1 to 100, 2)
+
+    for (num <- List(5,20,100)) {
Contributor: Put spaces after the commas here.

+      val sample = data.takeSample(withReplacement=false, num=num)
+      assert(sample.size === num)  // Got exactly num elements
+      assert(sample.toSet.size === num)  // Elements are distinct
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
Contributor: Indenting seems off here; there seem to be some tabs.

Contributor (Author): Will take care of the indent.

     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20)  // Got exactly 20 elements

@@ -483,6 +489,19 @@
       assert(sample.size === 20)  // Got exactly 20 elements
       assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
     }
+    {
+      val sample = data.takeSample(withReplacement=true, num=20)
+      assert(sample.size === 20)  // Got exactly 20 elements
+      assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
+    {
+      val sample = data.takeSample(withReplacement=true, num=100)
+      assert(sample.size === 100)  // Got exactly 100 elements
+      // Chance of getting all distinct elements is astronomically low, so test we got < 100
+      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 100, seed)
       assert(sample.size === 100)  // Got exactly 100 elements
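The "astronomically low" comment in the new test can be made precise: the probability that 100 uniform draws with replacement from 100 elements are all distinct is 100!/100^100, roughly 9×10^-43. A quick Scala check of that figure:

    // Probability that 100 draws with replacement are all distinct:
    // 100!/100^100 = product over k of (100 - k) / 100, for k = 0 to 99.
    val p = (0 until 100).map(k => (100.0 - k) / 100.0).product
    println(p)  // ≈ 9.3e-43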
13 changes: 6 additions & 7 deletions python/pyspark/rdd.py
@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import random

 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long

@@ -332,7 +333,7 @@
         .reduceByKey(lambda x, _: x) \
         .map(lambda (x, _): x)

-    def sample(self, withReplacement, fraction, seed):
+    def sample(self, withReplacement, fraction, seed=None):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).

@@ -344,7 +345,7 @@
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

     # this is ported from scala/spark/RDD.scala
-    def takeSample(self, withReplacement, num, seed):
+    def takeSample(self, withReplacement, num, seed=None):
         """
         Return a fixed-size sampled subset of this RDD (currently requires numpy).

@@ -381,13 +382,11 @@
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
+        random.seed(seed)
Contributor: Is this the global random object? Library code should not be setting the seed and then calling randint. Is there no equivalent of java.util.Random that you can create and use here?

         while len(samples) < total:
-            if seed > sys.maxint - 2:
-                seed = -1
-            seed += 1
-            samples = self.sample(withReplacement, fraction, seed).collect()
+            samples = self.sample(withReplacement, fraction, random.randint(0,sys.maxint)).collect()

-        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler = RDDSampler(withReplacement, fraction, random.randint(0,sys.maxint))
         sampler.shuffle(samples)
         return samples[0:total]
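One way to address the reviewer's concern about reseeding the global random module would be a private random.Random instance per call, so library code never mutates shared state. A hypothetical sketch of that alternative — not the code in this patch; Python 2 era, hence sys.maxint:

    import random
    import sys

    # Hypothetical: derive per-retry seeds from a private generator
    # instead of calling random.seed() on the shared module state.
    def make_seed_source(seed=None):
        rng = random.Random(seed)  # Random(None) self-seeds from the OS
        return lambda: rng.randint(0, sys.maxint)

    next_seed = make_seed_source(42)
    print(next_seed())  # deterministic, given the outer seed
    print(next_seed())  # a fresh, independent seed for the next retry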
2 changes: 1 addition & 1 deletion python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@
 import random

 class RDDSampler(object):
-    def __init__(self, withReplacement, fraction, seed):
+    def __init__(self, withReplacement, fraction, seed=None):
Contributor: Don't you need to do something later to deal with seed being None? Does random.seed(None) do the right thing?

Contributor (Author): @mateiz, both the numpy random and the Python standard-library random functions should handle None fine:

    import numpy
    numpy.random.RandomState(None)
    import random
    random.seed(None)

Contributor: @smartnut007, @mateiz, if we should not use nanoTime as a seed, we should not pass None to random.seed(None) either; random.seed(None) will use time.time(). As a reply to your very first comment: I think consistency is important, but that's just my opinion; others may disagree. You should do what you think is right.

Contributor: Good point. In that case, just use Python's built-in random, or create a Random object if there isn't a global one you can call.

Contributor: You can simply add `if seed is None: seed = random.random()`, or whatever is required for it.

Contributor (Author): @advancedxy @mateiz Actually, the default Python seed in random.seed(None) comes from the following snippet, and picking random.random() will go through it anyway (Python 2.7.5):

    from binascii import hexlify as _hexlify
    from os import urandom as _urandom
    long(_hexlify(_urandom(16)), 16)

I generally think it's better to leave these things to the language implementors. But if we need the code to look similar, I would follow the last suggestion in RDDSampler:

    if seed is None: seed = random.randint(0, sys.maxint)

Here sys.maxint is arch-specific and generally 64-bit on 64-bit machines. Can you guys let me know which one?
         try:
             import numpy
             self._use_numpy = True
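For RDDSampler itself, the normalization the author floats at the end of the thread would look roughly like this — a hypothetical sketch under the thread's assumptions, not the merged code:

    import random
    import sys

    class SamplerSketch(object):
        # Hypothetical: resolve seed=None to a concrete value up front,
        # so reproducible per-partition seeds can be derived from it later.
        def __init__(self, withReplacement, fraction, seed=None):
            if seed is None:
                seed = random.randint(0, sys.maxint)  # arch-specific bound
            self._seed = seed
            self._withReplacement = withReplacement
            self._fraction = fraction

    sampler = SamplerSketch(False, 0.1)       # self-seeded
    repro = SamplerSketch(False, 0.1, seed=7) # reproducible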
@@ -168,7 +168,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
   def references = Set.empty
 }

-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
   extends UnaryNode {

   def output = child.output
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -256,10 +256,11 @@ class SchemaRDD(
    * @group Query
    */
   @Experimental
+  override
   def sample(
-      fraction: Double,
       withReplacement: Boolean = true,
-      seed: Int = (math.random * 1000).toInt) =
+      fraction: Double,
+      seed: Long) =
     new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
@@ -44,7 +44,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   }
 }

-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
   extends UnaryNode {

   override def output = child.output
Expand Down