
[SPARK-2724] Python version of RandomRDDGenerators #1628

Closed · wants to merge 11 commits
@@ -24,10 +24,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
@@ -453,4 +455,99 @@ class PythonMLLibAPI extends Serializable {
    val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
    ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
  }

  // Used by the *RDD methods to get a default seed if one is not passed in from pyspark
  private def getSeedOrDefault(seed: java.lang.Long): Long = {
    if (seed == null) Utils.random.nextLong else seed
  }

  // Used by the *RDD methods to get a default numPartitions if one is not passed in from pyspark
  private def getNumPartitionsOrDefault(numPartitions: java.lang.Integer,
      jsc: JavaSparkContext): Int = {
    if (numPartitions == null) {
      jsc.sc.defaultParallelism
    } else {
      numPartitions
    }
  }

  // Note: in the following methods, numPartitions and seed are boxed (java.lang.Integer and
  // java.lang.Long) so that pyspark can pass null for either argument; see the sketch after
  // this diff for the defaulting behavior

  /**
   * Java stub for Python mllib RandomRDDGenerators.uniformRDD()
   */
  def uniformRDD(jsc: JavaSparkContext,
      size: Long,
      numPartitions: java.lang.Integer,
      seed: java.lang.Long): JavaRDD[Array[Byte]] = {
    val parts = getNumPartitionsOrDefault(numPartitions, jsc)
    val s = getSeedOrDefault(seed)
    RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble)
  }

  /**
   * Java stub for Python mllib RandomRDDGenerators.normalRDD()
   */
  def normalRDD(jsc: JavaSparkContext,
      size: Long,
      numPartitions: java.lang.Integer,
      seed: java.lang.Long): JavaRDD[Array[Byte]] = {
    val parts = getNumPartitionsOrDefault(numPartitions, jsc)
    val s = getSeedOrDefault(seed)
    RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble)
  }

  /**
   * Java stub for Python mllib RandomRDDGenerators.poissonRDD()
   */
  def poissonRDD(jsc: JavaSparkContext,
      mean: Double,
      size: Long,
      numPartitions: java.lang.Integer,
      seed: java.lang.Long): JavaRDD[Array[Byte]] = {
    val parts = getNumPartitionsOrDefault(numPartitions, jsc)
    val s = getSeedOrDefault(seed)
    RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble)
  }

  /**
   * Java stub for Python mllib RandomRDDGenerators.uniformVectorRDD()
   */
  def uniformVectorRDD(jsc: JavaSparkContext,
      numRows: Long,
      numCols: Int,
      numPartitions: java.lang.Integer,
      seed: java.lang.Long): JavaRDD[Array[Byte]] = {
    val parts = getNumPartitionsOrDefault(numPartitions, jsc)
    val s = getSeedOrDefault(seed)
    RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector)
  }

  /**
   * Java stub for Python mllib RandomRDDGenerators.normalVectorRDD()
   */
  def normalVectorRDD(jsc: JavaSparkContext,
      numRows: Long,
      numCols: Int,
      numPartitions: java.lang.Integer,
      seed: java.lang.Long): JavaRDD[Array[Byte]] = {
    val parts = getNumPartitionsOrDefault(numPartitions, jsc)
    val s = getSeedOrDefault(seed)
    RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector)
  }

  /**
   * Java stub for Python mllib RandomRDDGenerators.poissonVectorRDD()
   */
  def poissonVectorRDD(jsc: JavaSparkContext,
      mean: Double,
      numRows: Long,
      numCols: Int,
      numPartitions: java.lang.Integer,
      seed: java.lang.Long): JavaRDD[Array[Byte]] = {
    val parts = getNumPartitionsOrDefault(numPartitions, jsc)
    val s = getSeedOrDefault(seed)
    RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector)
  }
}
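
The boxed-argument pattern in this diff is easy to exercise in isolation. Below is a minimal, self-contained sketch of the same null-defaulting logic: Py4J translates Python's None into a JVM null, which only boxed types (java.lang.Long / java.lang.Integer) can carry, so the stubs substitute a default and then unbox. The SeedDefaults object, its method names, and the use of scala.util.Random are illustrative assumptions, not part of the PR.

```scala
import scala.util.Random

// Hypothetical standalone mirror of getSeedOrDefault/getNumPartitionsOrDefault.
object SeedDefaults {
  def seedOrDefault(seed: java.lang.Long): Long =
    if (seed == null) Random.nextLong() else seed // null => fresh random seed

  def partitionsOrDefault(numPartitions: java.lang.Integer, fallback: Int): Int =
    if (numPartitions == null) fallback else numPartitions

  def main(args: Array[String]): Unit = {
    println(seedOrDefault(null))          // Python passed seed=None: random seed
    println(seedOrDefault(42L))           // explicit seed is unboxed and returned: 42
    println(partitionsOrDefault(null, 8)) // numPartitions=None: falls back to 8
  }
}
```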
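For context, here is a sketch of how the Scala API these stubs delegate to might be called directly. The RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed) shape is taken from the calls in the diff above; the object name, app name, and local master are illustrative assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.random.RandomRDDGenerators

object UniformRDDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("uniform-rdd-example").setMaster("local[2]"))
    // 1000 i.i.d. draws from U(0, 1) in 4 partitions with a fixed seed, i.e. the
    // RDD the uniformRDD stub above computes before serializing it for pyspark.
    val data = RandomRDDGenerators.uniformRDD(sc, 1000L, 4, 42L)
    println(s"count = ${data.count()}, mean ~ ${data.mean()}") // mean should be near 0.5
    sc.stop()
  }
}
```

On the Python side, the stub's JavaRDD[Array[Byte]] is deserialized back into an RDD of floats; passing seed=None or numPartitions=None from pyspark arrives here as null and triggers the defaults shown earlier.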