Skip to content

Commit 4ca5b1b

Browse files
committed
remove normalization from Lasso and update tests
1 parent f04fe8a commit 4ca5b1b

File tree

3 files changed

+31
-50
lines changed

3 files changed

+31
-50
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import breeze.linalg.{Vector => BV}
21-
2220
import org.apache.spark.SparkContext
23-
import org.apache.spark.rdd.RDD
24-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
21+
import org.apache.spark.mllib.linalg.Vector
2522
import org.apache.spark.mllib.optimization._
2623
import org.apache.spark.mllib.util.MLUtils
24+
import org.apache.spark.rdd.RDD
2725

2826
/**
2927
* Regression model trained using Lasso.
@@ -58,8 +56,7 @@ class LassoWithSGD private (
5856
var numIterations: Int,
5957
var regParam: Double,
6058
var miniBatchFraction: Double)
61-
extends GeneralizedLinearAlgorithm[LassoModel]
62-
with Serializable {
59+
extends GeneralizedLinearAlgorithm[LassoModel] with Serializable {
6360

6461
val gradient = new LeastSquaresGradient()
6562
val updater = new L1Updater()
@@ -71,10 +68,6 @@ class LassoWithSGD private (
7168
// We don't want to penalize the intercept, so set this to false.
7269
super.setIntercept(false)
7370

74-
private var yMean = 0.0
75-
private var xColMean: BV[Double] = _
76-
private var xColSd: BV[Double] = _
77-
7871
/**
7972
* Construct a Lasso object with default parameters
8073
*/
@@ -87,31 +80,7 @@ class LassoWithSGD private (
8780
}
8881

8982
override protected def createModel(weights: Vector, intercept: Double) = {
90-
val weightsMat = weights.toBreeze
91-
val weightsScaled = weightsMat :/ xColSd
92-
val interceptScaled = yMean - weightsMat.dot(xColMean :/ xColSd)
93-
94-
new LassoModel(Vectors.fromBreeze(weightsScaled), interceptScaled)
95-
}
96-
97-
override def run(input: RDD[LabeledPoint], initialWeights: Vector): LassoModel = {
98-
val nfeatures: Int = input.first.features.size
99-
val nexamples: Long = input.count()
100-
101-
// To avoid penalizing the intercept, we center and scale the data.
102-
val stats = MLUtils.computeStats(input, nfeatures, nexamples)
103-
yMean = stats._1
104-
xColMean = stats._2.toBreeze
105-
xColSd = stats._3.toBreeze
106-
107-
val normalizedData = input.map { point =>
108-
val yNormalized = point.label - yMean
109-
val featuresMat = point.features.toBreeze
110-
val featuresNormalized = (featuresMat - xColMean) :/ xColSd
111-
LabeledPoint(yNormalized, Vectors.fromBreeze(featuresNormalized))
112-
}
113-
114-
super.run(normalizedData, initialWeights)
83+
new LassoModel(weights, intercept)
11584
}
11685
}
11786

mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
5959
@Test
6060
public void runLassoUsingConstructor() {
6161
int nPoints = 10000;
62-
double A = 2.0;
62+
double A = 0.0;
6363
double[] weights = {-1.5, 1.0e-2};
6464

6565
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
@@ -80,7 +80,7 @@ public void runLassoUsingConstructor() {
8080
@Test
8181
public void runLassoUsingStaticMethods() {
8282
int nPoints = 10000;
83-
double A = 2.0;
83+
double A = 0.0;
8484
double[] weights = {-1.5, 1.0e-2};
8585

8686
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,

mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,27 @@ class LassoSuite extends FunSuite with LocalSparkContext {
4040
val B = -1.5
4141
val C = 1.0e-2
4242

43-
val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
44-
45-
val testRDD = sc.parallelize(testData, 2)
46-
testRDD.cache()
43+
val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
44+
.map { case LabeledPoint(label, features) =>
45+
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
46+
}
47+
val testRDD = sc.parallelize(testData, 2).cache()
4748

4849
val ls = new LassoWithSGD()
4950
ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
5051

5152
val model = ls.run(testRDD)
5253
val weight0 = model.weights(0)
5354
val weight1 = model.weights(1)
54-
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
55-
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
56-
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
55+
val weight2 = model.weights(2)
56+
assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
57+
assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
58+
assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
5759

5860
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
61+
.map { case LabeledPoint(label, features) =>
62+
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
63+
}
5964
val validationRDD = sc.parallelize(validationData, 2)
6065

6166
// Test prediction on RDD.
@@ -73,25 +78,32 @@ class LassoSuite extends FunSuite with LocalSparkContext {
7378
val C = 1.0e-2
7479

7580
val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
81+
.map { case LabeledPoint(label, features) =>
82+
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
83+
}
7684

85+
val initialA = -1.0
7786
val initialB = -1.0
7887
val initialC = -1.0
79-
val initialWeights = Vectors.dense(initialB, initialC)
88+
val initialWeights = Vectors.dense(initialA, initialB, initialC)
8089

81-
val testRDD = sc.parallelize(testData, 2)
82-
testRDD.cache()
90+
val testRDD = sc.parallelize(testData, 2).cache()
8391

8492
val ls = new LassoWithSGD()
8593
ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
8694

8795
val model = ls.run(testRDD, initialWeights)
8896
val weight0 = model.weights(0)
8997
val weight1 = model.weights(1)
90-
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
91-
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
92-
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
98+
val weight2 = model.weights(2)
99+
assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
100+
assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
101+
assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
93102

94103
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
104+
.map { case LabeledPoint(label, features) =>
105+
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
106+
}
95107
val validationRDD = sc.parallelize(validationData,2)
96108

97109
// Test prediction on RDD.

0 commit comments

Comments (0)