Skip to content

Commit f0fe616

Browse files
committed
add a test for sparse linear regression
1 parent 44733e1 commit f0fe616

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
23+
import org.apache.spark.mllib.linalg.Vectors
2324

2425
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
2526

@@ -84,4 +85,37 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
8485
// Test prediction on Array.
8586
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
8687
}
88+
89+
// Checks that SGD can learn Y = 10*X1 + 10*X10000 when the two informative features sit at
// the first and last positions of a 10000-dimensional sparse vector.
test("sparse linear regression without intercept") {
  val numFeatures = 10000
  val lastIndex = numFeatures - 1

  // Re-embeds a two-element feature vector into a sparse vector of size `numFeatures`:
  // element 0 stays at index 0, element 1 moves to the last index. Kept as a function value
  // so the same conversion serves both the training RDD and the local validation data.
  val toSparse: LabeledPoint => LabeledPoint = { case LabeledPoint(y, features) =>
    val entries = Seq((0, features(0)), (lastIndex, features(1)))
    LabeledPoint(y, Vectors.sparse(numFeatures, entries))
  }

  // Training set: generated points spread over 2 partitions, sparsified, then cached.
  val trainingPoints =
    LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42)
  val trainingRDD = sc.parallelize(trainingPoints, 2).map(toSparse).cache()

  val regression = new LinearRegressionWithSGD().setIntercept(false)
  regression.optimizer.setNumIterations(1000).setStepSize(1.0)

  val model = regression.run(trainingRDD)

  // Intercept fitting was disabled above, so the model must report exactly zero.
  assert(model.intercept === 0.0)

  // Both informative weights should land within +/-1 of the true value 10.
  val learned = model.weights
  assert(learned.size === numFeatures)
  Seq(0, lastIndex).foreach { i =>
    assert(learned(i) >= 9.0 && learned(i) <= 11.0)
  }

  // Validation set: same generator settings except the last argument (17 instead of 42;
  // presumably the random seed -- confirm against LinearDataGenerator).
  val holdOut =
    LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17).map(toSparse)
  val holdOutRDD = sc.parallelize(holdOut, 2)

  // Test prediction on RDD.
  val rddPredictions = model.predict(holdOutRDD.map(_.features)).collect()
  validatePrediction(rddPredictions, holdOut)

  // Test prediction on Array.
  val localPredictions = holdOut.map(point => model.predict(point.features))
  validatePrediction(localPredictions, holdOut)
}
87121
}

0 commit comments

Comments
 (0)