Skip to content

Commit 43caac9

Browse files
committed
add predict(JavaRDD) to RegressionModel, ClassificationModel, and KMeans
1 parent fbfe69d commit 43caac9

File tree

6 files changed

+76
-2
lines changed

6 files changed

+76
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.mllib.classification
1919

20+
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.api.java.JavaRDD
2022
import org.apache.spark.mllib.linalg.Vector
2123
import org.apache.spark.rdd.RDD
22-
import org.apache.spark.annotation.Experimental
2324

2425
/**
2526
* :: Experimental ::
@@ -43,4 +44,12 @@ trait ClassificationModel extends Serializable {
4344
* @return predicted category from the trained model
4445
*/
4546
def predict(testData: Vector): Double
47+
48+
/**
49+
* Predict values for examples stored in a JavaRDD.
50+
* @param testData JavaRDD representing data points to be predicted
51+
* @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
52+
*/
53+
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
54+
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
4655
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.clustering
1919

20+
import org.apache.spark.api.java.JavaRDD
2021
import org.apache.spark.rdd.RDD
2122
import org.apache.spark.SparkContext._
2223
import org.apache.spark.mllib.linalg.Vector
@@ -40,6 +41,10 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
4041
points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
4142
}
4243

44+
/** Maps given points to their cluster indices. */
45+
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
46+
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
47+
4348
/**
4449
* Return the K-means cost (sum of squared distances of points to their nearest center) for this
4550
* model on the given data.

mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20+
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.api.java.JavaRDD
2022
import org.apache.spark.rdd.RDD
2123
import org.apache.spark.mllib.linalg.Vector
22-
import org.apache.spark.annotation.Experimental
2324

2425
@Experimental
2526
trait RegressionModel extends Serializable {
@@ -38,4 +39,12 @@ trait RegressionModel extends Serializable {
3839
* @return Double prediction from the trained model
3940
*/
4041
def predict(testData: Vector): Double
42+
43+
/**
44+
* Predict values for examples stored in a JavaRDD.
45+
* @param testData JavaRDD representing data points to be predicted
46+
* @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
47+
*/
48+
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
49+
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
4150
}

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import org.apache.spark.api.java.JavaRDD;
2121
import org.apache.spark.api.java.JavaSparkContext;
22+
import org.apache.spark.api.java.function.Function;
23+
import org.apache.spark.mllib.linalg.Vector;
2224
import org.apache.spark.mllib.linalg.Vectors;
2325
import org.apache.spark.mllib.regression.LabeledPoint;
2426
import org.junit.After;
@@ -87,4 +89,18 @@ public void runUsingStaticMethods() {
8789
int numAccurate2 = validatePrediction(POINTS, model2);
8890
Assert.assertEquals(POINTS.size(), numAccurate2);
8991
}
92+
93+
@Test
94+
public void testPredictJavaRDD() {
95+
JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
96+
NaiveBayesModel model = NaiveBayes.train(examples.rdd());
97+
JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
98+
@Override
99+
public Vector call(LabeledPoint v) throws Exception {
100+
return v.features();
101+
}});
102+
JavaRDD<Double> predictions = model.predict(vectors);
103+
// Should be able to get the first prediction.
104+
predictions.first();
105+
}
90106
}

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,18 @@ public void runKMeansUsingConstructor() {
8888
.run(data.rdd());
8989
assertEquals(expectedCenter, model.clusterCenters()[0]);
9090
}
91+
92+
@Test
93+
public void testPredictJavaRDD() {
94+
List<Vector> points = Lists.newArrayList(
95+
Vectors.dense(1.0, 2.0, 6.0),
96+
Vectors.dense(1.0, 3.0, 0.0),
97+
Vectors.dense(1.0, 4.0, 6.0)
98+
);
99+
JavaRDD<Vector> data = sc.parallelize(points, 2);
100+
KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
101+
JavaRDD<Integer> predictions = model.predict(data);
102+
// Should be able to get the first prediction.
103+
predictions.first();
104+
}
91105
}

mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
import org.junit.Before;
2626
import org.junit.Test;
2727

28+
import org.apache.spark.api.java.function.Function;
2829
import org.apache.spark.api.java.JavaRDD;
2930
import org.apache.spark.api.java.JavaSparkContext;
31+
import org.apache.spark.mllib.linalg.Vector;
3032
import org.apache.spark.mllib.util.LinearDataGenerator;
3133

3234
public class JavaLinearRegressionSuite implements Serializable {
@@ -92,4 +94,23 @@ public void runLinearRegressionUsingStaticMethods() {
9294
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
9395
}
9496

97+
@Test
98+
public void testPredictJavaRDD() {
99+
int nPoints = 100;
100+
double A = 0.0;
101+
double[] weights = {10, 10};
102+
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
103+
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
104+
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
105+
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
106+
JavaRDD<Vector> vectors = testRDD.map(new Function<LabeledPoint, Vector>() {
107+
@Override
108+
public Vector call(LabeledPoint v) throws Exception {
109+
return v.features();
110+
}
111+
});
112+
JavaRDD<Double> predictions = model.predict(vectors);
113+
// Should be able to get the first prediction.
114+
predictions.first();
115+
}
95116
}

0 commit comments

Comments
 (0)