Skip to content

Commit d7f629f

Browse files
committed
fix a bug in GLM when intercept is not used
1 parent 8237df8 commit d7f629f

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
136136

137137
// Prepend an extra variable consisting of all 1.0's for the intercept.
138138
val data = if (addIntercept) {
139-
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
139+
input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
140140
} else {
141141
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
142142
}
143143

144144
val initialWeightsWithIntercept = if (addIntercept) {
145-
initialWeights.+:(1.0)
145+
0.0 +: initialWeights
146146
} else {
147147
initialWeights
148148
}
149149

150-
val weights = optimizer.optimize(data, initialWeightsWithIntercept)
151-
val intercept = weights(0)
152-
val weightsScaled = weights.tail
150+
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
153151

154-
val model = createModel(weightsScaled, intercept)
152+
val (intercept, weights) = if (addIntercept) {
153+
(weightsWithIntercept(0), weightsWithIntercept.tail)
154+
} else {
155+
(0.0, weightsWithIntercept)
156+
}
157+
158+
logInfo("Final weights " + weights.mkString(","))
159+
logInfo("Final intercept " + intercept)
155160

156-
logInfo("Final model weights " + model.weights.mkString(","))
157-
logInfo("Final model intercept " + model.intercept)
158-
model
161+
createModel(weights, intercept)
159162
}
160163
}

mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.scalatest.BeforeAndAfterAll
2120
import org.scalatest.FunSuite
2221

2322
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
@@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
5756
// Test prediction on Array.
5857
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
5958
}
59+
60+
// Test if we can correctly learn Y = 10*X1 + 10*X2
61+
test("linear regression without intercept") {
62+
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
63+
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
64+
val linReg = new LinearRegressionWithSGD().setIntercept(false)
65+
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
66+
67+
val model = linReg.run(testRDD)
68+
69+
assert(model.intercept === 0.0)
70+
assert(model.weights.length === 2)
71+
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
72+
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
73+
74+
val validationData = LinearDataGenerator.generateLinearInput(
75+
0.0, Array(10.0, 10.0), 100, 17)
76+
val validationRDD = sc.parallelize(validationData, 2).cache()
77+
78+
// Test prediction on RDD.
79+
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
80+
81+
// Test prediction on Array.
82+
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
83+
}
6084
}

0 commit comments

Comments
 (0)