
Commit 8d13233

Added methods:
* Classifier: batch predictRaw()
* Predictor: train() without paramMap
* ProbabilisticClassificationModel.predictProbabilities()
* Java versions of all above batch methods + others

Updated LogisticRegressionSuite. Updated JavaLogisticRegressionSuite to match LogisticRegressionSuite.
1 parent 1680905 commit 8d13233
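
Taken together, these additions give Predictor implementations a strongly typed RDD-based path alongside the schema-based fit()/transform() path. A rough end-to-end sketch in Scala; `sc` and the toy data are assumptions for illustration, not part of the commit, while the signatures are the ones added in the diffs below:

import org.apache.spark.ml.LabeledPoint
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Toy training data (assumes sc: SparkContext).
val training: RDD[LabeledPoint] = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(2.0)),
  LabeledPoint(0.0, Vectors.dense(-2.0))))

// New: train() without a ParamMap.
val model = new LogisticRegression().train(training)

// New: batch prediction methods over RDD[Vector].
val features: RDD[Vector] = training.map(_.features)
val raw: RDD[Vector] = model.predictRaw(features)              // raw margins per class
val probs: RDD[Vector] = model.predictProbabilities(features)  // class probabilities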

7 files changed: 174 additions, 22 deletions

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 10 additions & 0 deletions
@@ -18,8 +18,10 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
 
 /**
  * Params for classification.
@@ -72,6 +74,14 @@ abstract class ClassificationModel[M <: ClassificationModel[M]]
    */
   def predictRaw(features: Vector): Vector
 
+  /** Batch version of [[predictRaw]] */
+  def predictRaw(dataset: RDD[Vector]): RDD[Vector] = dataset.map(predictRaw)
+
+  /** Java-friendly batch version of [[predictRaw]] */
+  def predictRaw(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
+    dataset.rdd.map(predictRaw).toJavaRDD()
+  }
+
   // TODO: accuracy(dataset: RDD[LabeledPoint]): Double (follow-up PR)
 
 }
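
Since the batch overload is defined as a plain map over the per-instance method, batch and single-instance calls should agree element-wise. A minimal sketch, assuming a fitted `model: ClassificationModel[_]` and `vectors: RDD[Vector]` (both hypothetical names):

// Batch call: one job, results stay distributed.
val batched: RDD[Vector] = model.predictRaw(vectors)
// Per-instance call on the driver, for comparison.
val single: Vector = model.predictRaw(vectors.first())
assert(batched.first() == single)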

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 2 additions & 1 deletion
@@ -102,7 +102,8 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
    * NOTE: This does NOT support instance weights.
    * @param dataset Training data. Instance weights are ignored.
    */
-  def train(dataset: RDD[LabeledPoint]): LogisticRegressionModel = train(dataset, new ParamMap())
+  override def train(dataset: RDD[LabeledPoint]): LogisticRegressionModel =
+    train(dataset, new ParamMap()) // Override documentation
 }

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala

Lines changed: 23 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.impl.estimator
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.{Estimator, LabeledPoint, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -101,6 +102,18 @@ private[ml] abstract class Predictor[Learner <: Predictor[Learner, M], M <: Pred
    * These values override any specified in this Estimator's embedded ParamMap.
    */
   def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): M
+
+  /**
+   * Same as [[fit()]], but using strong types.
+   * @param dataset Training data
+   */
+  def train(dataset: RDD[LabeledPoint]): M = train(dataset, new ParamMap())
+
+  /** Java-friendly version of [[train()]]. */
+  def train(dataset: JavaRDD[LabeledPoint], paramMap: ParamMap): M = train(dataset.rdd, paramMap)
+
+  /** Java-friendly version of [[train()]]. */
+  def train(dataset: JavaRDD[LabeledPoint]): M = train(dataset.rdd)
 }
 
 private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
@@ -156,6 +169,16 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
    */
   def predict(features: Vector): Double
 
+  /** Java-friendly version of [[predict()]]. */
+  def predict(dataset: JavaRDD[Vector], paramMap: ParamMap): JavaRDD[java.lang.Double] = {
+    predict(dataset.rdd, paramMap).map(_.asInstanceOf[java.lang.Double]).toJavaRDD()
+  }
+
+  /** Java-friendly version of [[predict()]]. */
+  def predict(dataset: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+    predict(dataset.rdd, new ParamMap).map(_.asInstanceOf[java.lang.Double]).toJavaRDD()
+  }
+
   /**
    * Create a copy of the model.
    * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
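
The relationship between the overloads is worth spelling out: the no-ParamMap train() is sugar for an empty override map, and the JavaRDD variants only unwrap to the underlying RDD. A sketch, assuming hypothetical `lr: LogisticRegression` and `labeled: RDD[LabeledPoint]`:

// These two calls are equivalent: an empty ParamMap overrides nothing,
// so the estimator's embedded params are used as-is.
val modelA = lr.train(labeled)
val modelB = lr.train(labeled, new ParamMap())

The Java-friendly predict() additionally boxes each result into java.lang.Double, since JavaRDD needs a reference element type to be usable from Java.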

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/ProbabilisticClassificationModel.scala

Lines changed: 9 additions & 0 deletions
@@ -17,7 +17,9 @@
 
 package org.apache.spark.ml.impl.estimator
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
 
 /**
  * Trait for a [[org.apache.spark.ml.classification.ClassificationModel]] which can output
@@ -34,4 +36,11 @@ private[ml] trait ProbabilisticClassificationModel {
    */
   def predictProbabilities(features: Vector): Vector
 
+  /** Batch version of [[predictProbabilities()]] */
+  def predictProbabilities(features: RDD[Vector]): RDD[Vector] = features.map(predictProbabilities)
+
+  /** Java-friendly batch version of [[predictProbabilities()]] */
+  def predictProbabilities(features: JavaRDD[Vector]): JavaRDD[Vector] = {
+    features.rdd.map(predictProbabilities).toJavaRDD()
+  }
 }
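
For a well-formed probabilistic classifier, each output vector of the batch method is a distribution over classes. A small sanity-check sketch, assuming a hypothetical fitted `model` mixing in this trait and `features: RDD[Vector]`:

val probs: RDD[Vector] = model.predictProbabilities(features)
// Each vector has one entry per class, and the entries sum to 1.
probs.collect().foreach { p =>
  assert(math.abs(p.toArray.sum - 1.0) < 1e-8)
}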

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 2 additions & 1 deletion
@@ -78,7 +78,8 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel
    * NOTE: This does NOT support instance weights.
    * @param dataset Training data. Instance weights are ignored.
    */
-  def train(dataset: RDD[LabeledPoint]): LinearRegressionModel = train(dataset, new ParamMap())
+  override def train(dataset: RDD[LabeledPoint]): LinearRegressionModel =
+    train(dataset, new ParamMap()) // Override documentation
 }
 
 /**

mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java

Lines changed: 113 additions & 7 deletions
@@ -17,31 +17,54 @@
 
 package org.apache.spark.ml.classification;
 
+import scala.Tuple2;
+
 import java.io.Serializable;
+import java.lang.Math;
+import java.util.ArrayList;
 import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.ml.LabeledPoint;
+import org.apache.spark.sql.Row;
+
 
 public class JavaLogisticRegressionSuite implements Serializable {
 
   private transient JavaSparkContext jsc;
   private transient SQLContext jsql;
   private transient DataFrame dataset;
+
+  private transient JavaRDD<LabeledPoint> datasetRDD;
+  private transient JavaRDD<Vector> featuresRDD;
+  private double eps = 1e-5;
+
   @Before
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
     jsql = new SQLContext(jsc);
-    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
-    dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+    List<LabeledPoint> points = new ArrayList<LabeledPoint>();
+    for (org.apache.spark.mllib.regression.LabeledPoint lp:
+        generateLogisticInputAsList(1.0, 1.0, 100, 42)) {
+      points.add(new LabeledPoint(lp.label(), lp.features()));
+    }
+    datasetRDD = jsc.parallelize(points, 2);
+    featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
+      @Override public Vector call(LabeledPoint lp) { return lp.features(); }
+    });
+    dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+    dataset.registerTempTable("dataset");
   }
 
   @After
@@ -51,29 +74,112 @@ public void tearDown() {
   }
 
   @Test
-  public void logisticRegression() {
+  public void logisticRegressionDefaultParams() {
     LogisticRegression lr = new LogisticRegression();
+    assert(lr.getLabelCol().equals("label"));
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
     predictions.collectAsList();
+    // Check defaults
+    assert(model.getThreshold() == 0.5);
+    assert(model.getFeaturesCol().equals("features"));
+    assert(model.getPredictionCol().equals("prediction"));
+    assert(model.getScoreCol().equals("score"));
   }
 
   @Test
   public void logisticRegressionWithSetters() {
+    // Set params, train, and check as many params as we can.
     LogisticRegression lr = new LogisticRegression()
       .setMaxIter(10)
-      .setRegParam(1.0);
+      .setRegParam(1.0)
+      .setThreshold(0.6)
+      .setScoreCol("probability");
     LogisticRegressionModel model = lr.fit(dataset);
+    assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
+    assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
+    assert(model.fittingParamMap().get(lr.threshold()).get() == 0.6);
+    assert(model.getThreshold() == 0.6);
+
+    // Modify model params, and check that the params worked.
+    model.setThreshold(1.0);
+    model.transform(dataset).registerTempTable("predAllZero");
+    DataFrame predAllZero = jsql.sql("SELECT prediction, probability FROM predAllZero");
+    for (Row r: predAllZero.collectAsList()) {
+      assert(r.getDouble(0) == 0.0);
+    }
+    // Call transform with params, and check that the params worked.
+    /* TODO: USE THIS
     model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
-      .registerTempTable("prediction");
+        .registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
     predictions.collectAsList();
+    */
+
+    model.transform(dataset, model.threshold().w(0.0), model.scoreCol().w("myProb"))
+      .registerTempTable("predNotAllZero");
+    DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+    boolean foundNonZero = false;
+    for (Row r: predNotAllZero.collectAsList()) {
+      if (r.getDouble(0) != 0.0) foundNonZero = true;
+    }
+    assert(foundNonZero);
+
+    // Call fit() with new params, and check as many params as we can.
+    LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
+      lr.threshold().w(0.4), lr.scoreCol().w("theProb"));
+    assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
+    assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
+    assert(model2.fittingParamMap().get(lr.threshold()).get() == 0.4);
+    assert(model2.getThreshold() == 0.4);
+    assert(model2.getScoreCol().equals("theProb"));
   }
 
   @Test
-  public void logisticRegressionFitWithVarargs() {
+  public void logisticRegressionPredictorClassifierMethods() {
     LogisticRegression lr = new LogisticRegression();
-    lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+
+    // fit() vs. train()
+    LogisticRegressionModel model1 = lr.fit(dataset);
+    LogisticRegressionModel model2 = lr.train(datasetRDD);
+    assert(model1.intercept() == model2.intercept());
+    assert(model1.weights().equals(model2.weights()));
+    assert(model1.numClasses() == model2.numClasses());
+    assert(model1.numClasses() == 2);
+
+    // transform() vs. predict()
+    model1.transform(dataset).registerTempTable("transformed");
+    DataFrame trans = jsql.sql("SELECT prediction FROM transformed");
+    JavaRDD<Double> preds = model1.predict(featuresRDD);
+    for (scala.Tuple2<Row, Double> trans_pred: trans.toJavaRDD().zip(preds).collect()) {
+      double t = trans_pred._1().getDouble(0);
+      double p = trans_pred._2();
+      assert(t == p);
+    }
+
+    // Check various types of predictions.
+    JavaRDD<Vector> rawPredictions = model1.predictRaw(featuresRDD);
+    JavaRDD<Vector> probabilities = model1.predictProbabilities(featuresRDD);
+    JavaRDD<Double> predictions = model1.predict(featuresRDD);
+    double threshold = model1.getThreshold();
+    for (Tuple2<Vector, Vector> raw_prob: rawPredictions.zip(probabilities).collect()) {
+      Vector raw = raw_prob._1();
+      Vector prob = raw_prob._2();
+      for (int i = 0; i < raw.size(); ++i) {
+        double r = raw.apply(i);
+        double p = prob.apply(i);
+        double pFromR = 1.0 / (1.0 + Math.exp(-r));
+        assert(Math.abs(p - pFromR) < eps);
+      }
+    }
+    for (Tuple2<Vector, Double> prob_pred: probabilities.zip(predictions).collect()) {
+      Vector prob = prob_pred._1();
+      double pred = prob_pred._2();
+      double probOfPred = prob.apply((int) pred);
+      for (int i = 0; i < prob.size(); ++i) {
+        assert(probOfPred >= prob.apply(i));
+      }
+    }
   }
 }
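
One design note on the comparisons above: they rely on RDD.zip, which requires both sides to have the same number of partitions with the same number of elements in each. That holds here because every compared RDD is a map() over the same featuresRDD, and map() preserves partition structure. A Scala sketch of the invariant, with `featuresRDD` (treated here as an RDD[Vector]) standing in for the suite's field:

// Both sides derive from the same parent via map(), so zip() is safe.
val lhs = featuresRDD.map(identity)
val rhs = featuresRDD.map(identity)
val pairs = lhs.zip(rhs)  // would fail at runtime if the shapes diverged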

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 15 additions & 13 deletions
@@ -45,7 +45,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(lr.getLabelCol == "label")
     val model = lr.fit(dataset)
     model.transform(dataset)
-      .select("label", "prediction")
+      .select('label, 'score, 'prediction)
       .collect()
     // Check defaults
     assert(model.getThreshold === 0.5)
@@ -55,7 +55,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
   }
 
   test("logistic regression with setters") {
-    // Set params, train, and check as many as we can.
+    // Set params, train, and check as many params as we can.
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
@@ -77,27 +77,27 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
     assert(model.getThreshold === 0.6)
 
-    // Modify model params, and check that they work.
+    // Modify model params, and check that the params worked.
     model.setThreshold(1.0)
     val predAllZero = model.transform(dataset)
       .select('prediction, 'probability)
       .collect()
       .map { case Row(pred: Double, prob: Double) => pred }
     assert(predAllZero.forall(_ === 0.0))
-    // Call transform with params, and check that they work.
+    // Call transform with params, and check that the params worked.
     val predNotAllZero =
       model.transform(dataset, model.threshold -> 0.0, model.scoreCol -> "myProb")
         .select('prediction, 'myProb)
         .collect()
         .map { case Row(pred: Double, prob: Double) => pred }
     assert(predNotAllZero.exists(_ !== 0.0))
 
-    // Call fit() with new params, and check as many as we can.
+    // Call fit() with new params, and check as many params as we can.
     val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
       lr.scoreCol -> "theProb")
-    assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
-    assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
-    assert(model2.fittingParamMap.get(lr.threshold) === Some(0.4))
+    assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+    assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+    assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
     assert(model2.getThreshold === 0.4)
     assert(model2.getScoreCol == "theProb")
   }
@@ -112,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
      LabeledPoint(label, features)
    }
-    val features = rdd.map(_.features)
+    val featuresRDD = rdd.map(_.features)
     val model2 = lr.train(rdd)
     assert(model1.intercept == model2.intercept)
     assert(model1.weights.equals(model2.weights))
@@ -127,15 +127,17 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     }
 
     // Check various types of predictions.
-    val allPredictions = features.map { f =>
-      (model1.predictRaw(f), model1.predictProbabilities(f), model1.predict(f))
-    }.collect()
+    val rawPredictions = model1.predictRaw(featuresRDD)
+    val probabilities = model1.predictProbabilities(featuresRDD)
+    val predictions = model1.predict(featuresRDD)
     val threshold = model1.getThreshold
-    allPredictions.foreach { case (raw: Vector, prob: Vector, pred: Double) =>
+    rawPredictions.zip(probabilities).collect().foreach { case (raw: Vector, prob: Vector) =>
       val computeProbFromRaw: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
       raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
         assert(r ~== p relTol eps)
       }
+    }
+    probabilities.zip(predictions).collect().foreach { case (prob: Vector, pred: Double) =>
       val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
       assert(pred == predFromProb)
     }
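
Both suites check the same identity: for binary logistic regression, mapping each raw margin m through the logistic function 1 / (1 + e^(-m)) must reproduce the corresponding class probability. A standalone worked example in plain Scala (the negated-margin relationship between the two classes is an assumption of this illustration, matching how the suites pair raw and probability vectors):

val m = 1.5
val p = 1.0 / (1.0 + math.exp(-m))  // probability of the positive class, ~0.8176
val q = 1.0 / (1.0 + math.exp(m))   // probability of the negative class, ~0.1824
assert(math.abs((p + q) - 1.0) < 1e-12)  // the two class probabilities sum to 1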
