Skip to content

Commit fec348a

Browse files
committed
Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types.
1 parent 8316d5e commit fec348a

File tree

10 files changed

+343
-16
lines changed

10 files changed

+343
-16
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,5 +121,7 @@ public static void main(String[] args) {
121121
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
122122
+ ", prediction=" + r.get(3));
123123
}
124+
125+
jsc.stop();
124126
}
125127
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.ml;
19+
20+
import java.util.List;
21+
22+
import com.google.common.collect.Lists;
23+
24+
import org.apache.spark.SparkConf;
25+
import org.apache.spark.api.java.JavaRDD;
26+
import org.apache.spark.api.java.JavaSparkContext;
27+
import org.apache.spark.ml.classification.Classifier;
28+
import org.apache.spark.ml.classification.ClassificationModel;
29+
import org.apache.spark.ml.param.IntParam;
30+
import org.apache.spark.ml.param.ParamMap;
31+
import org.apache.spark.ml.param.Params;
32+
import org.apache.spark.ml.param.Params$;
33+
import org.apache.spark.mllib.linalg.BLAS;
34+
import org.apache.spark.mllib.linalg.Vector;
35+
import org.apache.spark.mllib.linalg.Vectors;
36+
import org.apache.spark.mllib.regression.LabeledPoint;
37+
import org.apache.spark.sql.DataFrame;
38+
import org.apache.spark.sql.Row;
39+
import org.apache.spark.sql.SQLContext;
40+
41+
42+
/**
43+
* A simple example demonstrating how to write your own learning algorithm using Estimator,
44+
* Transformer, and other abstractions.
45+
* This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
46+
*
47+
* Run with
48+
* <pre>
49+
* bin/run-example ml.JavaDeveloperApiExample
50+
* </pre>
51+
*/
52+
public class JavaDeveloperApiExample {
53+
54+
public static void main(String[] args) throws Exception {
55+
SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
56+
JavaSparkContext jsc = new JavaSparkContext(conf);
57+
SQLContext jsql = new SQLContext(jsc);
58+
59+
// Prepare training data.
60+
List<LabeledPoint> localTraining = Lists.newArrayList(
61+
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
62+
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
63+
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
64+
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
65+
DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
66+
67+
// Create a LogisticRegression instance. This instance is an Estimator.
68+
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
69+
// Print out the parameters, documentation, and any default values.
70+
System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
71+
72+
// We may set parameters using setter methods.
73+
lr.setMaxIter(10);
74+
75+
// Learn a LogisticRegression model. This uses the parameters stored in lr.
76+
MyJavaLogisticRegressionModel model = lr.fit(training);
77+
78+
// Prepare test data.
79+
List<LabeledPoint> localTest = Lists.newArrayList(
80+
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
81+
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
82+
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
83+
DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
84+
85+
// Make predictions on test documents. cvModel uses the best model found (lrModel).
86+
DataFrame results = model.transform(test);
87+
double sumPredictions = 0;
88+
for (Row r : results.select("features", "label", "prediction").collect()) {
89+
sumPredictions += r.getDouble(2);
90+
}
91+
if (sumPredictions != 0.0) {
92+
throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
93+
" even though all weights are 0!");
94+
}
95+
96+
jsc.stop();
97+
}
98+
}
99+
100+
/**
101+
* Example of defining a type of {@link Classifier}.
102+
*
103+
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
104+
*/
105+
class MyJavaLogisticRegression
106+
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
107+
implements Params {
108+
109+
/**
110+
* Param for max number of iterations
111+
* <p/>
112+
* NOTE: The usual way to add a parameter to a model or algorithm is to include:
113+
* - val myParamName: ParamType
114+
* - def getMyParamName
115+
* - def setMyParamName
116+
*/
117+
IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
118+
119+
int getMaxIter() { return (int)get(maxIter); }
120+
121+
public MyJavaLogisticRegression() {
122+
setMaxIter(100);
123+
}
124+
125+
// The parameter setter is in this class since it should return type MyJavaLogisticRegression.
126+
MyJavaLogisticRegression setMaxIter(int value) {
127+
return (MyJavaLogisticRegression)set(maxIter, value);
128+
}
129+
130+
// This method is used by fit().
131+
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
132+
public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
133+
// Extract columns from data using helper method.
134+
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
135+
136+
// Do learning to estimate the weight vector.
137+
int numFeatures = oldDataset.take(1).get(0).features().size();
138+
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
139+
140+
// Create a model, and return it.
141+
return new MyJavaLogisticRegressionModel(this, paramMap, weights);
142+
}
143+
}
144+
145+
/**
146+
* Example of defining a type of {@link ClassificationModel}.
147+
*
148+
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
149+
*/
150+
class MyJavaLogisticRegressionModel
151+
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
152+
153+
private MyJavaLogisticRegression parent_;
154+
public MyJavaLogisticRegression parent() { return parent_; }
155+
156+
private ParamMap fittingParamMap_;
157+
public ParamMap fittingParamMap() { return fittingParamMap_; }
158+
159+
private Vector weights_;
160+
public Vector weights() { return weights_; }
161+
162+
public MyJavaLogisticRegressionModel(
163+
MyJavaLogisticRegression parent_,
164+
ParamMap fittingParamMap_,
165+
Vector weights_) {
166+
this.parent_ = parent_;
167+
this.fittingParamMap_ = fittingParamMap_;
168+
this.weights_ = weights_;
169+
}
170+
171+
// This uses the default implementation of transform(), which reads column "features" and outputs
172+
// columns "prediction" and "rawPrediction."
173+
174+
// This uses the default implementation of predict(), which chooses the label corresponding to
175+
// the maximum value returned by [[predictRaw()]].
176+
177+
/**
178+
* Raw prediction for each possible label.
179+
* The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
180+
* a measure of confidence in each possible label (where larger = more confident).
181+
* This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
182+
*
183+
* @return vector where element i is the raw prediction for label i.
184+
* This raw prediction may be any real number, where a larger value indicates greater
185+
* confidence for that label.
186+
*
187+
* In Java, we have to make this method public since Java does not understand Scala's protected
188+
* modifier.
189+
*/
190+
public Vector predictRaw(Vector features) {
191+
double margin = BLAS.dot(features, weights_);
192+
// There are 2 classes (binary classification), so we return a length-2 vector,
193+
// where index i corresponds to class i (i = 0, 1).
194+
return Vectors.dense(-margin, margin);
195+
}
196+
197+
/**
198+
* Number of classes the label can take. 2 indicates binary classification.
199+
*/
200+
public int numClasses() { return 2; }
201+
202+
/**
203+
* Create a copy of the model.
204+
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
205+
* <p/>
206+
* This is used for the defaul implementation of [[transform()]].
207+
*
208+
* In Java, we have to make this method public since Java does not understand Scala's protected
209+
* modifier.
210+
*/
211+
public MyJavaLogisticRegressionModel copy() {
212+
MyJavaLogisticRegressionModel m =
213+
new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
214+
Params$.MODULE$.inheritValues(this.paramMap(), this, m);
215+
return m;
216+
}
217+
}

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,7 @@ public static void main(String[] args) {
107107
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
108108
+ ", prediction=" + r.get(3));
109109
}
110+
111+
jsc.stop();
110112
}
111113
}

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,7 @@ public static void main(String[] args) {
8888
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
8989
+ ", prediction=" + r.get(3));
9090
}
91+
92+
jsc.stop();
9193
}
9294
}

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
2929
/**
3030
* :: DeveloperApi ::
3131
* Params for classification.
32+
*
33+
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
3234
*/
3335
@DeveloperApi
34-
trait ClassifierParams extends PredictorParams
36+
private[spark] trait ClassifierParams extends PredictorParams
3537
with HasRawPredictionCol {
3638

3739
override protected def validateAndTransformSchema(
@@ -53,9 +55,11 @@ trait ClassifierParams extends PredictorParams
5355
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
5456
* @tparam Learner Concrete Estimator type
5557
* @tparam M Concrete Model type
58+
*
59+
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
5660
*/
5761
@AlphaComponent
58-
abstract class Classifier[
62+
private[spark] abstract class Classifier[
5963
FeaturesType,
6064
Learner <: Classifier[FeaturesType, Learner, M],
6165
M <: ClassificationModel[FeaturesType, M]]
@@ -75,8 +79,11 @@ abstract class Classifier[
7579
*
7680
* @tparam FeaturesType Type of input features. E.g., [[Vector]]
7781
* @tparam M Concrete Model type
82+
*
83+
* NOTE: This is currently private[spark] but will be made public later once it is stabilized.
7884
*/
7985
@AlphaComponent
86+
private[spark]
8087
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
8188
extends PredictionModel[FeaturesType, M] with ClassifierParams {
8289

@@ -161,7 +168,7 @@ private[ml] object ClassificationModel {
161168
* should already be done.
162169
* @return (number of columns added, transformed dataset)
163170
*/
164-
private[ml] def transformColumnsImpl[FeaturesType](
171+
def transformColumnsImpl[FeaturesType](
165172
dataset: DataFrame,
166173
model: ClassificationModel[FeaturesType, _],
167174
map: ParamMap): (Int, DataFrame) = {

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ package org.apache.spark.ml.classification
2020
import org.apache.spark.annotation.AlphaComponent
2121
import org.apache.spark.ml.param._
2222
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
23-
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
23+
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
2424
import org.apache.spark.sql.DataFrame
25+
import org.apache.spark.sql.Dsl._
26+
import org.apache.spark.sql.types.DoubleType
2527
import org.apache.spark.storage.StorageLevel
2628

2729

@@ -102,13 +104,82 @@ class LogisticRegressionModel private[ml] (
102104
1.0 / (1.0 + math.exp(-m))
103105
}
104106

107+
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
108+
// This is overridden (a) to be more efficient (avoiding re-computing values when creating
109+
// multiple output columns) and (b) to handle threshold, which the abstractions do not use.
110+
// TODO: We should abstract away the steps defined by UDFs below so that the abstractions
111+
// can call whichever UDFs are needed to create the output columns.
112+
113+
// Check schema
114+
transformSchema(dataset.schema, paramMap, logging = true)
115+
116+
val map = this.paramMap ++ paramMap
117+
118+
// Output selected columns only.
119+
// This is a bit complicated since it tries to avoid repeated computation.
120+
// rawPrediction (-margin, margin)
121+
// probability (1.0-score, score)
122+
// prediction (max margin)
123+
var tmpData = dataset
124+
var numColsOutput = 0
125+
if (map(rawPredictionCol) != "") {
126+
val features2raw: Vector => Vector = (features) => predictRaw(features)
127+
tmpData = tmpData.select($"*",
128+
callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
129+
numColsOutput += 1
130+
}
131+
if (map(probabilityCol) != "") {
132+
if (map(rawPredictionCol) != "") {
133+
val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
134+
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
135+
Vectors.dense(1.0 - prob1, prob1)
136+
}
137+
tmpData = tmpData.select($"*",
138+
callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
139+
} else {
140+
val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
141+
tmpData = tmpData.select($"*",
142+
callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
143+
}
144+
numColsOutput += 1
145+
}
146+
if (map(predictionCol) != "") {
147+
val t = map(threshold)
148+
if (map(probabilityCol) != "") {
149+
val predict: Vector => Double = { probs: Vector =>
150+
if (probs(1) > t) 1.0 else 0.0
151+
}
152+
tmpData = tmpData.select($"*",
153+
callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
154+
} else if (map(rawPredictionCol) != "") {
155+
val predict: Vector => Double = { rawPreds: Vector =>
156+
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
157+
if (prob1 > t) 1.0 else 0.0
158+
}
159+
tmpData = tmpData.select($"*",
160+
callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
161+
} else {
162+
val predict: Vector => Double = (features: Vector) => this.predict(features)
163+
tmpData = tmpData.select($"*",
164+
callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
165+
}
166+
numColsOutput += 1
167+
}
168+
if (numColsOutput == 0) {
169+
this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
170+
" since no output columns were set.")
171+
}
172+
tmpData
173+
}
174+
105175
override val numClasses: Int = 2
106176

107177
/**
108178
* Predict label for the given feature vector.
109179
* The behavior of this can be adjusted using [[threshold]].
110180
*/
111181
override protected def predict(features: Vector): Double = {
182+
println(s"LR.predict with threshold: ${paramMap(threshold)}")
112183
if (score(features) > paramMap(threshold)) 1 else 0
113184
}
114185

0 commit comments

Comments
 (0)