@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
20
20
import org .scalatest .FunSuite
21
21
22
22
import org .apache .spark .mllib .util .{LinearDataGenerator , LocalSparkContext }
23
+ import org .apache .spark .mllib .linalg .Vectors
23
24
24
25
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
25
26
@@ -84,4 +85,37 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
84
85
// Test prediction on Array.
85
86
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
86
87
}
88
+
89
// Checks that training on sparse input can learn Y = 10*X1 + 10*X10000, i.e. that the two
// informative features are recovered at the first and last index of a 10000-wide vector.
test(" sparse linear regression without intercept") {
  val dim = 10000
  val lastIdx = dim - 1

  // Re-homes a two-feature point inside a `dim`-wide sparse vector: the first feature stays
  // at index 0 and the second one moves to the last index. Kept as a function value so the
  // same conversion serves both the RDD and the local validation collection.
  val toSparse: LabeledPoint => LabeledPoint = { case LabeledPoint(y, dense) =>
    LabeledPoint(y, Vectors.sparse(dim, Seq((0, dense(0)), (lastIdx, dense(1)))))
  }

  val trainingPoints = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42)
  val trainingRDD = sc.parallelize(trainingPoints, 2).map(toSparse).cache()

  val algorithm = new LinearRegressionWithSGD().setIntercept(false)
  algorithm.optimizer.setNumIterations(1000).setStepSize(1.0)

  val model = algorithm.run(trainingRDD)

  // The intercept was switched off above, so the fitted value must be exactly zero.
  assert(model.intercept === 0.0)

  val learned = model.weights
  assert(learned.size === dim)
  // Both informative coordinates must land in [9.0, 11.0]; index 0 is checked first.
  Seq(0, lastIdx).foreach { idx =>
    val w = learned(idx)
    assert(w >= 9.0 && w <= 11.0)
  }

  val denseValidation = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
  val sparseValidation = denseValidation.map(toSparse)
  val sparseValidationRDD = sc.parallelize(sparseValidation, 2)

  // Test prediction on RDD.
  val rddPredictions = model.predict(sparseValidationRDD.map(_.features)).collect()
  validatePrediction(rddPredictions, sparseValidation)

  // Test prediction on Array.
  val localPredictions = sparseValidation.map(point => model.predict(point.features))
  validatePrediction(localPredictions, sparseValidation)
}
87
121
}
0 commit comments