Skip to content

Commit 135ab72

Browse files
committed
merge glm
2 parents 3f346ba + 0e57aa4 commit 135ab72

File tree

5 files changed

+65
-26
lines changed

5 files changed

+65
-26
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
4242
* @param weightMatrix Column vector containing the weights of the model
4343
* @param intercept Intercept of the model.
4444
*/
45-
def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double
45+
protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double
4646

4747
/**
4848
* Predict values for the given data set using the model trained.
@@ -116,6 +116,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
116116
run(input, initialWeights)
117117
}
118118

119+
/** Prepends one to the input vector. */
119120
private def prependOne(vector: Vector): Vector = {
120121
val vectorWithIntercept = vector match {
121122
case dv: BDV[Double] => BDV.vertcat(BDV.ones(1), dv)
@@ -154,8 +155,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
154155
val intercept = if (addIntercept) brzWeightsWithIntercept(0) else 0.0
155156
val brzWeights = if (addIntercept) brzWeightsWithIntercept(1 to -1) else brzWeightsWithIntercept
156157

157-
val model = createModel(Vectors.fromBreeze(brzWeights), intercept)
158-
159-
model
158+
createModel(Vectors.fromBreeze(brzWeights), intercept)
160159
}
161160
}

mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ class LassoModel(
3636
extends GeneralizedLinearModel(weights, intercept)
3737
with RegressionModel with Serializable {
3838

39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
39+
override protected def predictPoint(
40+
dataMatrix: DoubleMatrix,
41+
weightMatrix: DoubleMatrix,
42+
intercept: Double): Double = {
4143
dataMatrix.dot(weightMatrix) + intercept
4244
}
4345
}
@@ -66,7 +68,7 @@ class LassoWithSGD private (
6668
.setMiniBatchFraction(miniBatchFraction)
6769

6870
// We don't want to penalize the intercept, so set this to false.
69-
setIntercept(false)
71+
super.setIntercept(false)
7072

7173
var yMean = 0.0
7274
var xColMean: DoubleMatrix = _
@@ -77,10 +79,16 @@ class LassoWithSGD private (
7779
*/
7880
def this() = this(1.0, 100, 1.0, 1.0)
7981

80-
def createModel(weights: Array[Double], intercept: Double) = {
81-
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
82+
override def setIntercept(addIntercept: Boolean): this.type = {
83+
// TODO: Support adding intercept.
84+
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
85+
this
86+
}
87+
88+
override protected def createModel(weights: Array[Double], intercept: Double) = {
89+
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
8290
val weightsScaled = weightsMat.div(xColSd)
83-
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
91+
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
8492

8593
new LassoModel(weightsScaled.data, interceptScaled)
8694
}

mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
3131
* @param intercept Intercept computed for this model.
3232
*/
3333
class LinearRegressionModel(
34-
override val weights: Array[Double],
35-
override val intercept: Double)
36-
extends GeneralizedLinearModel(weights, intercept)
37-
with RegressionModel with Serializable {
38-
39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
34+
override val weights: Array[Double],
35+
override val intercept: Double)
36+
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
37+
38+
override protected def predictPoint(
39+
dataMatrix: DoubleMatrix,
40+
weightMatrix: DoubleMatrix,
41+
intercept: Double): Double = {
4142
dataMatrix.dot(weightMatrix) + intercept
4243
}
4344
}
@@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
5556
var stepSize: Double,
5657
var numIterations: Int,
5758
var miniBatchFraction: Double)
58-
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
59-
with Serializable {
59+
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
6060

6161
val gradient = new LeastSquaresGradient()
6262
val updater = new SimpleUpdater()
@@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
6969
*/
7070
def this() = this(1.0, 100, 1.0)
7171

72-
def createModel(weights: Array[Double], intercept: Double) = {
72+
override protected def createModel(weights: Array[Double], intercept: Double) = {
7373
new LinearRegressionModel(weights, intercept)
7474
}
7575
}

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ class RidgeRegressionModel(
3535
extends GeneralizedLinearModel(weights, intercept)
3636
with RegressionModel with Serializable {
3737

38-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
39-
intercept: Double) = {
38+
override protected def predictPoint(
39+
dataMatrix: DoubleMatrix,
40+
weightMatrix: DoubleMatrix,
41+
intercept: Double): Double = {
4042
dataMatrix.dot(weightMatrix) + intercept
4143
}
4244
}
@@ -66,7 +68,7 @@ class RidgeRegressionWithSGD private (
6668
.setMiniBatchFraction(miniBatchFraction)
6769

6870
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
69-
setIntercept(false)
71+
super.setIntercept(false)
7072

7173
var yMean = 0.0
7274
var xColMean: DoubleMatrix = _
@@ -77,8 +79,14 @@ class RidgeRegressionWithSGD private (
7779
*/
7880
def this() = this(1.0, 100, 1.0, 1.0)
7981

80-
def createModel(weights: Array[Double], intercept: Double) = {
81-
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
82+
override def setIntercept(addIntercept: Boolean): this.type = {
83+
// TODO: Support adding intercept.
84+
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
85+
this
86+
}
87+
88+
override protected def createModel(weights: Array[Double], intercept: Double) = {
89+
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
8290
val weightsScaled = weightsMat.div(xColSd)
8391
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
8492

mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.scalatest.BeforeAndAfterAll
2120
import org.scalatest.FunSuite
2221

2322
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
@@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
5756
// Test prediction on Array.
5857
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
5958
}
59+
60+
// Test if we can correctly learn Y = 10*X1 + 10*X2
61+
test("linear regression without intercept") {
62+
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
63+
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
64+
val linReg = new LinearRegressionWithSGD().setIntercept(false)
65+
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
66+
67+
val model = linReg.run(testRDD)
68+
69+
assert(model.intercept === 0.0)
70+
assert(model.weights.length === 2)
71+
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
72+
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
73+
74+
val validationData = LinearDataGenerator.generateLinearInput(
75+
0.0, Array(10.0, 10.0), 100, 17)
76+
val validationRDD = sc.parallelize(validationData, 2).cache()
77+
78+
// Test prediction on RDD.
79+
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
80+
81+
// Test prediction on Array.
82+
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
83+
}
6084
}

0 commit comments

Comments
 (0)