Commit 1c61723

* Made ProbabilisticClassificationModel into a subclass of ClassificationModel. Also introduced ProbabilisticClassifier.
  * This was to support the output column “probabilityCol” in transform().
* SPARK-4942: ML Transformers should allow output cols to be turned on/off
  * Update validateAndTransformSchema
  * Update transform
* Update based on design review
  * Make the prediction API protected, but add output columns
  * Remove the training API
* LogisticRegression:
  * Changed output column “score” to “probability” in logreg.
  * I also implemented transform() to avoid repeated computation. This improves upon the default implementation in ProbabilisticClassificationModel. However, it’s a lot of code, so I would be fine with removing it. There is also a question of whether all algorithms should implement a method which would allow the ProbabilisticClassificationModel.transform implementation to avoid repeated computation (see the sketch after this list):
    * protected def raw2prob(rawPredictions: Vector): Vector = // compute probabilities from raw predictions
* trait Params:
  * Changed set() and get() from private[ml] to protected. This was needed for the example of defining a class from outside of the MLlib namespace.
* VectorUDT: Changed from private[spark] to public. This is needed for outside users to write their own validateAndTransformSchema() methods using vectors.
* Add example of defining a class from outside of the MLlib namespace.
* Scala
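Note that the raw2prob hook above is only floated as a question in this commit, not committed code. A minimal sketch of what such a hook could compute for binary logistic regression (the object name and the placement as a standalone helper are ours, for illustration only):

import org.apache.spark.mllib.linalg.{Vector, Vectors}

object Raw2ProbSketch {
  // Hypothetical helper, not part of this commit: map raw predictions (-margin, margin)
  // to class probabilities using the logistic (sigmoid) function.
  def raw2prob(rawPredictions: Vector): Vector = {
    val margin = rawPredictions(1)               // raw score for class 1
    val p1 = 1.0 / (1.0 + math.exp(-margin))     // sigmoid of the margin
    Vectors.dense(1.0 - p1, p1)                  // (P(label = 0), P(label = 1))
  }
}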
1 parent 4e2f711 · commit 1c61723

20 files changed: +820 −338 lines

examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java

Lines changed: 2 additions & 2 deletions
@@ -116,9 +116,9 @@ public static void main(String[] args) {
 
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     cvModel.transform(test).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+    DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
   }

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java

Lines changed: 3 additions & 3 deletions
@@ -81,7 +81,7 @@ public static void main(String[] args) {
 
     // One can also combine ParamMaps.
     ParamMap paramMap2 = new ParamMap();
-    paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
+    paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
     ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
 
     // Now learn a new model using the paramMapCombined parameters.
@@ -98,8 +98,8 @@ public static void main(String[] args) {
 
     // Make predictions on test documents using the Transformer.transform() method.
     // LogisticRegression.transform will only use the 'features' column.
-    // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-    // column since we renamed the lr.scoreCol parameter previously.
+    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     model2.transform(test).registerTempTable("results");
     DataFrame results =
       jsql.sql("SELECT features, label, probability, prediction FROM results");

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ public static void main(String[] args) {
     model.transform(test).registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
         + ", prediction=" + r.get(3));
     }
   }

examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala

Lines changed: 4 additions & 3 deletions
@@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
 import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{Row, SQLContext}
 
 /**
@@ -100,10 +101,10 @@ object CrossValidatorExample {
 
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     cvModel.transform(test)
-      .select("id", "text", "score", "prediction")
+      .select('id, 'text, 'probability, 'prediction)
       .collect()
-      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
-        println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+        println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
       }
 
     sc.stop()
examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala

Lines changed: 197 additions & 0 deletions
This file was added.

@@ -0,0 +1,197 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataType, SchemaRDD, Row, SQLContext}

/**
 * A simple example demonstrating how to write your own learning algorithm using Estimator,
 * Transformer, and other abstractions.
 * This mimics [[org.apache.spark.ml.classification.LogisticRegression]].
 * Run with
 * {{{
 * bin/run-example ml.DeveloperApiExample
 * }}}
 */
object DeveloperApiExample {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("DeveloperApiExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    // Prepare training data.
    // We use LabeledPoint, which is a case class.  Spark SQL can convert RDDs of Java Beans
    // into SchemaRDDs, where it uses the bean metadata to infer the schema.
    val training = sparkContext.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
      LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
      LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))

    // Create a LogisticRegression instance.  This instance is an Estimator.
    val lr = new MyLogisticRegression()
    // Print out the parameters, documentation, and any default values.
    println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n")

    // We may set parameters using setter methods.
    lr.setMaxIter(10)

    // Learn a LogisticRegression model.  This uses the parameters stored in lr.
    val model = lr.fit(training)

    // Prepare test data.
    val test = sparkContext.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
      LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
      LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

    // Make predictions on test data.
    val sumPredictions: Double = model.transform(test)
      .select('features, 'label, 'prediction)
      .collect()
      .map { case Row(features: Vector, label: Double, prediction: Double) =>
        prediction
      }.sum
    assert(sumPredictions == 0.0,
      "MyLogisticRegression predicted something other than 0, even though all weights are 0!")
  }
}

/**
 * Example of defining a parameter trait for a user-defined type of [[Classifier]].
 *
 * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
 */
private trait MyLogisticRegressionParams extends ClassifierParams {

  /** param for max number of iterations */
  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
  def getMaxIter: Int = get(maxIter)
}

/**
 * Example of defining a type of [[Classifier]].
 *
 * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
 */
private class MyLogisticRegression
  extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
  with MyLogisticRegressionParams {

  setMaxIter(100) // Initialize

  def setMaxIter(value: Int): this.type = set(maxIter, value)

  override def fit(dataset: SchemaRDD, paramMap: ParamMap): MyLogisticRegressionModel = {
    // Check schema (types).  This allows early failure before running the algorithm.
    transformSchema(dataset.schema, paramMap, logging = true)

    // Extract columns from data using helper method.
    val oldDataset = extractLabeledPoints(dataset, paramMap)

    // Combine given parameters with the embedded parameters, where the given paramMap overrides
    // any embedded settings.
    val map = this.paramMap ++ paramMap

    // Do learning to estimate the weight vector.
    val numFeatures = oldDataset.take(1)(0).features.size
    val weights = Vectors.zeros(numFeatures) // Learning would happen here.

    // Create a model to return.
    val lrm = new MyLogisticRegressionModel(this, map, weights)

    // Copy model params.
    // An Estimator stores the parameters for the Model it produces, and this copies any relevant
    // parameters to the model.
    Params.inheritValues(map, this, lrm)

    // Return the learned model.
    lrm
  }

  /**
   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
   * This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data.
   */
  override protected def featuresDataType: DataType = new VectorUDT
}

/**
 * Example of defining a type of [[ClassificationModel]].
 *
 * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
 */
private class MyLogisticRegressionModel(
    override val parent: MyLogisticRegression,
    override val fittingParamMap: ParamMap,
    val weights: Vector)
  extends ClassificationModel[Vector, MyLogisticRegressionModel]
  with MyLogisticRegressionParams {

  // This uses the default implementation of transform(), which reads column "features" and outputs
  // columns "prediction" and "rawPrediction."

  // This uses the default implementation of predict(), which chooses the label corresponding to
  // the maximum value returned by [[predictRaw()]].

  /**
   * Raw prediction for each possible label.
   * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
   * a measure of confidence in each possible label (where larger = more confident).
   * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
   *
   * @return  vector where element i is the raw prediction for label i.
   *          This raw prediction may be any real number, where a larger value indicates greater
   *          confidence for that label.
   */
  override protected def predictRaw(features: Vector): Vector = {
    val margin = BLAS.dot(features, weights)
    // There are 2 classes (binary classification), so we return a length-2 vector,
    // where index i corresponds to class i (i = 0, 1).
    Vectors.dense(-margin, margin)
  }

  /** Number of classes the label can take.  2 indicates binary classification. */
  override val numClasses: Int = 2

  /**
   * Create a copy of the model.
   * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
   *
   * This is used for the default implementation of [[transform()]].
   */
  override protected def copy(): MyLogisticRegressionModel = {
    val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
    Params.inheritValues(this.paramMap, this, m)
    m
  }

  /**
   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
   * This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data.
   */
  override protected def featuresDataType: DataType = new VectorUDT
}
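
As the comments in MyLogisticRegressionModel note, the example relies on the default predict(), which chooses the label with the largest raw prediction. A minimal, self-contained sketch of that selection rule (the object and method names below are ours, not Spark's):

import org.apache.spark.mllib.linalg.{Vector, Vectors}

object ArgmaxPredictSketch {
  // Choose the label index whose raw prediction is largest, mirroring the behaviour
  // described for the default predict() in the example above.
  def predictFromRaw(rawPredictions: Vector): Double =
    rawPredictions.toArray.zipWithIndex.maxBy(_._1)._2.toDouble

  def main(args: Array[String]): Unit = {
    val raw = Vectors.dense(-1.5, 1.5)  // e.g. predictRaw output for a positive margin of 1.5
    println(predictFromRaw(raw))        // prints 1.0: class 1 has the larger raw score
  }
}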

examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala

Lines changed: 7 additions & 7 deletions
@@ -72,26 +72,26 @@ object SimpleParamsExample {
     paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
 
     // One can also combine ParamMaps.
-    val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
+    val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
     val paramMapCombined = paramMap ++ paramMap2
 
     // Now learn a new model using the paramMapCombined parameters.
     // paramMapCombined overrides all parameters set earlier via lr.set* methods.
     val model2 = lr.fit(training, paramMapCombined)
     println("Model 2 was fit using parameters: " + model2.fittingParamMap)
 
-    // Prepare test documents.
-    val test = sc.parallelize(Seq(
+    // Prepare test data.
+    val test = sparkContext.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
 
-    // Make predictions on test documents using the Transformer.transform() method.
+    // Make predictions on test data using the Transformer.transform() method.
     // LogisticRegression.transform will only use the 'features' column.
-    // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-    // column since we renamed the lr.scoreCol parameter previously.
+    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     model2.transform(test)
-      .select("features", "label", "probability", "prediction")
+      .select('features, 'label, 'myProbability, 'prediction)
      .collect()
      .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
        println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)

examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 4 additions & 3 deletions
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{Row, SQLContext}
 
 @BeanInfo
@@ -79,10 +80,10 @@ object SimpleTextClassificationPipeline {
 
     // Make predictions on test documents.
     model.transform(test)
-      .select("id", "text", "score", "prediction")
+      .select('id, 'text, 'probability, 'prediction)
      .collect()
-      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
-        println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+        println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
      }
 
    sc.stop()

mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala

Lines changed: 0 additions & 52 deletions
This file was deleted.
