Skip to content

Commit 57d54ab

Browse files
committed
* Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap.
* remove threshold_internal from logreg * Added Predictor.copy() * Extended LogisticRegressionSuite
1 parent e433872 commit 57d54ab

File tree

4 files changed

+78
-25
lines changed

4 files changed

+78
-25
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
7676
* These values override any specified in this Estimator's embedded ParamMap.
7777
*/
7878
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LogisticRegressionModel = {
79+
val map = this.paramMap ++ paramMap
7980
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
8081
org.apache.spark.mllib.regression.LabeledPoint(label, features)
8182
}
@@ -86,14 +87,13 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
8687
}
8788
val lr = new LogisticRegressionWithLBFGS
8889
lr.optimizer
89-
.setRegParam(paramMap(regParam))
90-
.setNumIterations(paramMap(maxIter))
90+
.setRegParam(map(regParam))
91+
.setNumIterations(map(maxIter))
9192
val model = lr.run(oldDataset)
92-
val lrm = new LogisticRegressionModel(this, paramMap, model.weights, model.intercept)
93+
val lrm = new LogisticRegressionModel(this, map, model.weights, model.intercept)
9394
if (handlePersistence) {
9495
oldDataset.unpersist()
9596
}
96-
lrm.setThreshold(paramMap(threshold))
9797
lrm
9898
}
9999
}
@@ -115,18 +115,9 @@ class LogisticRegressionModel private[ml] (
115115

116116
setThreshold(0.5)
117117

118-
def setThreshold(value: Double): this.type = {
119-
this.threshold_internal = value
120-
set(threshold, value)
121-
}
118+
def setThreshold(value: Double): this.type = set(threshold, value)
122119
def setScoreCol(value: String): this.type = set(scoreCol, value)
123120

124-
/**
125-
* Store for faster test-time prediction.
126-
* Initialized to threshold in fittingParamMap if exists, else default threshold.
127-
*/
128-
private var threshold_internal: Double = fittingParamMap.get(threshold).getOrElse(getThreshold)
129-
130121
private val margin: Vector => Double = (features) => {
131122
BLAS.dot(features, weights) + intercept
132123
}
@@ -142,7 +133,8 @@ class LogisticRegressionModel private[ml] (
142133
val scoreFunction = udf { v: Vector =>
143134
val margin = BLAS.dot(v, weights)
144135
1.0 / (1.0 + math.exp(-margin))
145-
val t = threshold_internal
136+
}
137+
val t = map(threshold)
146138
val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
147139
dataset
148140
.select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
@@ -151,12 +143,14 @@ class LogisticRegressionModel private[ml] (
151143

152144
override val numClasses: Int = 2
153145

146+
// TODO: Override batch predict() for efficiency.
147+
154148
/**
155149
* Predict label for the given feature vector.
156150
* The behavior of this can be adjusted using [[threshold]].
157151
*/
158152
override def predict(features: Vector): Double = {
159-
if (score(features) > threshold_internal) 1 else 0
153+
if (score(features) > paramMap(threshold)) 1 else 0
160154
}
161155

162156
override def predictProbabilities(features: Vector): Vector = {
@@ -168,4 +162,10 @@ class LogisticRegressionModel private[ml] (
168162
val m = margin(features)
169163
Vectors.dense(Array(-m, m))
170164
}
165+
166+
private[ml] override def copy(): LogisticRegressionModel = {
167+
val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
168+
Params.inheritValues(this.paramMap, this, m)
169+
m
170+
}
171171
}

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,33 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]
128128

129129
transformSchema(dataset.schema, paramMap, logging = true)
130130
val map = this.paramMap ++ paramMap
131+
val tmpModel = this.copy()
132+
Params.inheritValues(paramMap, parent, tmpModel)
131133
val pred: Vector => Double = (features) => {
132-
predict(features)
134+
tmpModel.predict(features)
133135
}
134136
dataset.select(Star(None), pred.call(map(featuresCol).attr) as map(predictionCol))
135137
}
136138

137139
/**
138-
* Default implementation.
139-
* Override for efficiency; e.g., this does not broadcast the model.
140+
* Default implementation using single-instance predict().
141+
*
142+
* Developers should override this for efficiency. E.g., this does not broadcast the model.
140143
*/
141-
def predict(dataset: RDD[Vector]): RDD[Double] = {
142-
dataset.map(predict)
144+
def predict(dataset: RDD[Vector], paramMap: ParamMap): RDD[Double] = {
145+
val tmpModel = this.copy()
146+
Params.inheritValues(paramMap, parent, tmpModel)
147+
dataset.map(tmpModel.predict)
143148
}
144149

145150
/**
146151
* Predict label for the given features.
147152
*/
148153
def predict(features: Vector): Double
154+
155+
/**
156+
* Create a copy of the model.
157+
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
158+
*/
159+
private[ml] def copy(): M
149160
}

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
1919

2020
import org.apache.spark.annotation.AlphaComponent
2121
import org.apache.spark.ml.LabeledPoint
22-
import org.apache.spark.ml.param.{ParamMap, HasMaxIter, HasRegParam}
22+
import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
2323
import org.apache.spark.mllib.linalg.{BLAS, Vector}
2424
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
2525
import org.apache.spark.rdd.RDD
@@ -89,4 +89,10 @@ class LinearRegressionModel private[ml] (
8989
override def predict(features: Vector): Double = {
9090
BLAS.dot(features, weights) + intercept
9191
}
92+
93+
private[ml] override def copy(): LinearRegressionModel = {
94+
val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
95+
Params.inheritValues(this.paramMap, this, m)
96+
m
97+
}
9298
}

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
2323
import org.apache.spark.mllib.util.MLlibTestSparkContext
24-
import org.apache.spark.sql.{SQLContext, DataFrame}
24+
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
25+
2526

2627
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
2728

@@ -32,21 +33,29 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
3233
super.beforeAll()
3334
sqlContext = new SQLContext(sc)
3435
dataset = sqlContext.createDataFrame(
35-
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
36+
sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
3637
}
3738

38-
test("logistic regression") {
39+
test("logistic regression: default params") {
3940
val lr = new LogisticRegression
4041
val model = lr.fit(dataset)
4142
model.transform(dataset)
4243
.select("label", "prediction")
4344
.collect()
45+
// Check defaults
46+
assert(model.getThreshold === 0.5)
47+
assert(model.getFeaturesCol == "features")
48+
assert(model.getPredictionCol == "prediction")
49+
assert(model.getScoreCol == "score")
4450
}
4551

4652
test("logistic regression with setters") {
53+
// Set params, train, and check as many as we can.
4754
val lr = new LogisticRegression()
4855
.setMaxIter(10)
4956
.setRegParam(1.0)
57+
.setThreshold(0.6)
58+
.setScoreCol("probability")
5059
val model = lr.fit(dataset)
5160
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
5261
.select("label", "score", "prediction")
@@ -58,6 +67,33 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
5867
val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
5968
model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
6069
.select("label", "probability", "prediction")
70+
assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
71+
assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
72+
assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
73+
assert(model.getThreshold === 0.6)
74+
75+
// Modify model params, and check that they work.
76+
model.setThreshold(1.0)
77+
val predAllZero = model.transform(dataset)
78+
.select('prediction, 'probability)
6179
.collect()
80+
.map { case Row(pred: Double, prob: Double) => pred }
81+
assert(predAllZero.forall(_ === 0.0))
82+
// Call transform with params, and check that they work.
83+
val predNotAllZero =
84+
model.transform(dataset, model.threshold -> 0.0, model.scoreCol -> "myProb")
85+
.select('prediction, 'myProb)
86+
.collect()
87+
.map { case Row(pred: Double, prob: Double) => pred }
88+
assert(predNotAllZero.exists(_ !== 0.0))
89+
90+
// Call fit() with new params, and check as many as we can.
91+
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
92+
lr.scoreCol -> "theProb")
93+
assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
94+
assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
95+
assert(model2.fittingParamMap.get(lr.threshold) === Some(0.4))
96+
assert(model2.getThreshold === 0.4)
97+
assert(model2.getScoreCol == "theProb")
6298
}
6399
}

0 commit comments

Comments
 (0)