Skip to content

Commit dc0c449

Browse files
jkbradleymengxr
authored andcommitted
[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs
This is part (1a) of the updates from the design doc in [https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs] **UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion. Here is a list of changes which are public: * new output columns: rawPrediction, probabilities * The “score” column is now called “rawPrediction” * Classifiers now provide numClasses * Params.get and .set are now protected instead of private[ml]. * ParamMap now has a size method. * new classes: LinearRegression, LinearRegressionModel * LogisticRegression now has an intercept. ### Sketch of APIs (most of which are private[spark] for now) Abstract classes for learning algorithms (+ corresponding Model abstractions): * Classifier (+ ClassificationModel) * ProbabilisticClassifier (+ ProbabilisticClassificationModel) * Regressor (+ RegressionModel) * Predictor (+ PredictionModel) * *For all of these*: * There is no strongly typed training-time API. * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms. Concrete classes: learning algorithms * LinearRegression * LogisticRegression (updated to use new abstract classes) * Also, removed "score" in favor of "probability" output column. Changed BinaryClassificationEvaluator to match. (SPARK-5031) Other updates: * params.scala: Changed Params.set/get to be protected instead of private[ml] * This was needed for the example of defining a class from outside of the MLlib namespace. * VectorUDT: Will later change from private[spark] to public. * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors. * Also, added equals() method.f * SPARK-4942 : ML Transformers should allow output cols to be turned on,off * Update validateAndTransformSchema * Update transform * (Updated examples, test suites according to other changes) New examples: * DeveloperApiExample.scala (example of defining algorithm from outside of the MLlib namespace) * Added Java version too Test Suites: * LinearRegressionSuite * LogisticRegressionSuite * + Java versions of above suites CC: mengxr etrain shivaram Author: Joseph K. Bradley <joseph@databricks.com> Closes apache#3637 from jkbradley/ml-api-part1 and squashes the following commits: 405bfb8 [Joseph K. Bradley] Last edits based on code review. Small cleanups fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types. 8316d5e [Joseph K. Bradley] fixes after rebasing on master fc62406 [Joseph K. Bradley] fixed test suites after last commit bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame) 9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it 216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged f549e34 [Joseph K. Bradley] Updates based on code review. Major ones are: * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT. * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value. 343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package 82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites 0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites c3c8da5 [Joseph K. Bradley] small cleanup 934f97b [Joseph K. Bradley] Fixed bugs from previous commit. 1c61723 [Joseph K. Bradley] * Made ProbabilisticClassificationModel into a subclass of ClassificationModel. Also introduced ProbabilisticClassifier. * This was to support output column “probabilityCol” in transform(). 4e2f711 [Joseph K. Bradley] rat fix bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite 8d13233 [Joseph K. Bradley] Added methods: * Classifier: batch predictRaw() * Predictor: train() without paramMap ProbabilisticClassificationModel.predictProbabilities() * Java versions of all above batch methods + others 1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0 adbe50a [Joseph K. Bradley] * fixed LinearRegression train() to use embedded paramMap * added Predictor.predict(RDD[Vector]) method * updated Linear/LogisticRegressionSuites 58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap. 57d54ab [Joseph K. Bradley] * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap. * remove threshold_internal from logreg * Added Predictor.copy() * Extended LogisticRegressionSuite e433872 [Joseph K. Bradley] Updated docs. Added LabeledPointSuite to spark.ml 54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly 0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup). 601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString. Cleaned up classes in class hierarchy, before implementing tests and examples. d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch 52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:
1 parent 6b88825 commit dc0c449

26 files changed

+1753
-156
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,12 @@ public static void main(String[] args) {
116116

117117
// Make predictions on test documents. cvModel uses the best model found (lrModel).
118118
cvModel.transform(test).registerTempTable("prediction");
119-
DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
119+
DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
120120
for (Row r: predictions.collect()) {
121-
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
121+
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
122122
+ ", prediction=" + r.get(3));
123123
}
124+
125+
jsc.stop();
124126
}
125127
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.ml;
19+
20+
import java.util.List;
21+
22+
import com.google.common.collect.Lists;
23+
24+
import org.apache.spark.SparkConf;
25+
import org.apache.spark.api.java.JavaRDD;
26+
import org.apache.spark.api.java.JavaSparkContext;
27+
import org.apache.spark.ml.classification.Classifier;
28+
import org.apache.spark.ml.classification.ClassificationModel;
29+
import org.apache.spark.ml.param.IntParam;
30+
import org.apache.spark.ml.param.ParamMap;
31+
import org.apache.spark.ml.param.Params;
32+
import org.apache.spark.ml.param.Params$;
33+
import org.apache.spark.mllib.linalg.BLAS;
34+
import org.apache.spark.mllib.linalg.Vector;
35+
import org.apache.spark.mllib.linalg.Vectors;
36+
import org.apache.spark.mllib.regression.LabeledPoint;
37+
import org.apache.spark.sql.DataFrame;
38+
import org.apache.spark.sql.Row;
39+
import org.apache.spark.sql.SQLContext;
40+
41+
42+
/**
43+
* A simple example demonstrating how to write your own learning algorithm using Estimator,
44+
* Transformer, and other abstractions.
45+
* This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
46+
*
47+
* Run with
48+
* <pre>
49+
* bin/run-example ml.JavaDeveloperApiExample
50+
* </pre>
51+
*/
52+
public class JavaDeveloperApiExample {
53+
54+
public static void main(String[] args) throws Exception {
55+
SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
56+
JavaSparkContext jsc = new JavaSparkContext(conf);
57+
SQLContext jsql = new SQLContext(jsc);
58+
59+
// Prepare training data.
60+
List<LabeledPoint> localTraining = Lists.newArrayList(
61+
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
62+
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
63+
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
64+
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
65+
DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
66+
67+
// Create a LogisticRegression instance. This instance is an Estimator.
68+
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
69+
// Print out the parameters, documentation, and any default values.
70+
System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
71+
72+
// We may set parameters using setter methods.
73+
lr.setMaxIter(10);
74+
75+
// Learn a LogisticRegression model. This uses the parameters stored in lr.
76+
MyJavaLogisticRegressionModel model = lr.fit(training);
77+
78+
// Prepare test data.
79+
List<LabeledPoint> localTest = Lists.newArrayList(
80+
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
81+
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
82+
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
83+
DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
84+
85+
// Make predictions on test documents. cvModel uses the best model found (lrModel).
86+
DataFrame results = model.transform(test);
87+
double sumPredictions = 0;
88+
for (Row r : results.select("features", "label", "prediction").collect()) {
89+
sumPredictions += r.getDouble(2);
90+
}
91+
if (sumPredictions != 0.0) {
92+
throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
93+
" even though all weights are 0!");
94+
}
95+
96+
jsc.stop();
97+
}
98+
}
99+
100+
/**
101+
* Example of defining a type of {@link Classifier}.
102+
*
103+
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
104+
*/
105+
class MyJavaLogisticRegression
106+
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
107+
implements Params {
108+
109+
/**
110+
* Param for max number of iterations
111+
* <p/>
112+
* NOTE: The usual way to add a parameter to a model or algorithm is to include:
113+
* - val myParamName: ParamType
114+
* - def getMyParamName
115+
* - def setMyParamName
116+
*/
117+
IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
118+
119+
int getMaxIter() { return (int)get(maxIter); }
120+
121+
public MyJavaLogisticRegression() {
122+
setMaxIter(100);
123+
}
124+
125+
// The parameter setter is in this class since it should return type MyJavaLogisticRegression.
126+
MyJavaLogisticRegression setMaxIter(int value) {
127+
return (MyJavaLogisticRegression)set(maxIter, value);
128+
}
129+
130+
// This method is used by fit().
131+
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
132+
public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
133+
// Extract columns from data using helper method.
134+
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
135+
136+
// Do learning to estimate the weight vector.
137+
int numFeatures = oldDataset.take(1).get(0).features().size();
138+
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
139+
140+
// Create a model, and return it.
141+
return new MyJavaLogisticRegressionModel(this, paramMap, weights);
142+
}
143+
}
144+
145+
/**
146+
* Example of defining a type of {@link ClassificationModel}.
147+
*
148+
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
149+
*/
150+
class MyJavaLogisticRegressionModel
151+
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
152+
153+
private MyJavaLogisticRegression parent_;
154+
public MyJavaLogisticRegression parent() { return parent_; }
155+
156+
private ParamMap fittingParamMap_;
157+
public ParamMap fittingParamMap() { return fittingParamMap_; }
158+
159+
private Vector weights_;
160+
public Vector weights() { return weights_; }
161+
162+
public MyJavaLogisticRegressionModel(
163+
MyJavaLogisticRegression parent_,
164+
ParamMap fittingParamMap_,
165+
Vector weights_) {
166+
this.parent_ = parent_;
167+
this.fittingParamMap_ = fittingParamMap_;
168+
this.weights_ = weights_;
169+
}
170+
171+
// This uses the default implementation of transform(), which reads column "features" and outputs
172+
// columns "prediction" and "rawPrediction."
173+
174+
// This uses the default implementation of predict(), which chooses the label corresponding to
175+
// the maximum value returned by [[predictRaw()]].
176+
177+
/**
178+
* Raw prediction for each possible label.
179+
* The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
180+
* a measure of confidence in each possible label (where larger = more confident).
181+
* This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
182+
*
183+
* @return vector where element i is the raw prediction for label i.
184+
* This raw prediction may be any real number, where a larger value indicates greater
185+
* confidence for that label.
186+
*
187+
* In Java, we have to make this method public since Java does not understand Scala's protected
188+
* modifier.
189+
*/
190+
public Vector predictRaw(Vector features) {
191+
double margin = BLAS.dot(features, weights_);
192+
// There are 2 classes (binary classification), so we return a length-2 vector,
193+
// where index i corresponds to class i (i = 0, 1).
194+
return Vectors.dense(-margin, margin);
195+
}
196+
197+
/**
198+
* Number of classes the label can take. 2 indicates binary classification.
199+
*/
200+
public int numClasses() { return 2; }
201+
202+
/**
203+
* Create a copy of the model.
204+
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
205+
* <p/>
206+
* This is used for the defaul implementation of [[transform()]].
207+
*
208+
* In Java, we have to make this method public since Java does not understand Scala's protected
209+
* modifier.
210+
*/
211+
public MyJavaLogisticRegressionModel copy() {
212+
MyJavaLogisticRegressionModel m =
213+
new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
214+
Params$.MODULE$.inheritValues(this.paramMap(), this, m);
215+
return m;
216+
}
217+
}

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public static void main(String[] args) {
8181

8282
// One can also combine ParamMaps.
8383
ParamMap paramMap2 = new ParamMap();
84-
paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
84+
paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
8585
ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
8686

8787
// Now learn a new model using the paramMapCombined parameters.
@@ -98,14 +98,16 @@ public static void main(String[] args) {
9898

9999
// Make predictions on test documents using the Transformer.transform() method.
100100
// LogisticRegression.transform will only use the 'features' column.
101-
// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
102-
// column since we renamed the lr.scoreCol parameter previously.
101+
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
102+
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
103103
model2.transform(test).registerTempTable("results");
104104
DataFrame results =
105-
jsql.sql("SELECT features, label, probability, prediction FROM results");
105+
jsql.sql("SELECT features, label, myProbability, prediction FROM results");
106106
for (Row r: results.collect()) {
107107
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
108108
+ ", prediction=" + r.get(3));
109109
}
110+
111+
jsc.stop();
110112
}
111113
}

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,10 @@ public static void main(String[] args) {
8585
model.transform(test).registerTempTable("prediction");
8686
DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
8787
for (Row r: predictions.collect()) {
88-
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
88+
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
8989
+ ", prediction=" + r.get(3));
9090
}
91+
92+
jsc.stop();
9193
}
9294
}

examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression
2323
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
2424
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
2525
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
26+
import org.apache.spark.mllib.linalg.Vector
2627
import org.apache.spark.sql.{Row, SQLContext}
2728

2829
/**
@@ -100,10 +101,10 @@ object CrossValidatorExample {
100101

101102
// Make predictions on test documents. cvModel uses the best model found (lrModel).
102103
cvModel.transform(test)
103-
.select("id", "text", "score", "prediction")
104+
.select("id", "text", "probability", "prediction")
104105
.collect()
105-
.foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
106-
println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
106+
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
107+
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
107108
}
108109

109110
sc.stop()

0 commit comments

Comments
 (0)