@@ -40,22 +40,27 @@ class LassoSuite extends FunSuite with LocalSparkContext {
4040 val B = - 1.5
4141 val C = 1.0e-2
4242
43- val testData = LinearDataGenerator .generateLinearInput(A , Array [Double ](B ,C ), nPoints, 42 )
44-
45- val testRDD = sc.parallelize(testData, 2 )
46- testRDD.cache()
43+ val testData = LinearDataGenerator .generateLinearInput(A , Array [Double ](B , C ), nPoints, 42 )
44+ .map { case LabeledPoint (label, features) =>
45+ LabeledPoint (label, Vectors .dense(1.0 +: features.toArray))
46+ }
47+ val testRDD = sc.parallelize(testData, 2 ).cache()
4748
4849 val ls = new LassoWithSGD ()
4950 ls.optimizer.setStepSize(1.0 ).setRegParam(0.01 ).setNumIterations(40 )
5051
5152 val model = ls.run(testRDD)
5253 val weight0 = model.weights(0 )
5354 val weight1 = model.weights(1 )
54- assert(model.intercept >= 1.9 && model.intercept <= 2.1 , model.intercept + " not in [1.9, 2.1]" )
55- assert(weight0 >= - 1.60 && weight0 <= - 1.40 , weight0 + " not in [-1.6, -1.4]" )
56- assert(weight1 >= - 1.0e-3 && weight1 <= 1.0e-3 , weight1 + " not in [-0.001, 0.001]" )
55+ val weight2 = model.weights(2 )
56+ assert(weight0 >= 1.9 && weight0 <= 2.1 , weight0 + " not in [1.9, 2.1]" )
57+ assert(weight1 >= - 1.60 && weight1 <= - 1.40 , weight1 + " not in [-1.6, -1.4]" )
58+ assert(weight2 >= - 1.0e-3 && weight2 <= 1.0e-3 , weight2 + " not in [-0.001, 0.001]" )
5759
5860 val validationData = LinearDataGenerator .generateLinearInput(A , Array [Double ](B ,C ), nPoints, 17 )
61+ .map { case LabeledPoint (label, features) =>
62+ LabeledPoint (label, Vectors .dense(1.0 +: features.toArray))
63+ }
5964 val validationRDD = sc.parallelize(validationData, 2 )
6065
6166 // Test prediction on RDD.
@@ -73,25 +78,32 @@ class LassoSuite extends FunSuite with LocalSparkContext {
7378 val C = 1.0e-2
7479
7580 val testData = LinearDataGenerator .generateLinearInput(A , Array [Double ](B , C ), nPoints, 42 )
81+ .map { case LabeledPoint (label, features) =>
82+ LabeledPoint (label, Vectors .dense(1.0 +: features.toArray))
83+ }
7684
85+ val initialA = - 1.0
7786 val initialB = - 1.0
7887 val initialC = - 1.0
79- val initialWeights = Vectors .dense(initialB, initialC)
88+ val initialWeights = Vectors .dense(initialA, initialB, initialC)
8089
81- val testRDD = sc.parallelize(testData, 2 )
82- testRDD.cache()
90+ val testRDD = sc.parallelize(testData, 2 ).cache()
8391
8492 val ls = new LassoWithSGD ()
8593 ls.optimizer.setStepSize(1.0 ).setRegParam(0.01 ).setNumIterations(40 )
8694
8795 val model = ls.run(testRDD, initialWeights)
8896 val weight0 = model.weights(0 )
8997 val weight1 = model.weights(1 )
90- assert(model.intercept >= 1.9 && model.intercept <= 2.1 , model.intercept + " not in [1.9, 2.1]" )
91- assert(weight0 >= - 1.60 && weight0 <= - 1.40 , weight0 + " not in [-1.6, -1.4]" )
92- assert(weight1 >= - 1.0e-3 && weight1 <= 1.0e-3 , weight1 + " not in [-0.001, 0.001]" )
98+ val weight2 = model.weights(2 )
99+ assert(weight0 >= 1.9 && weight0 <= 2.1 , weight0 + " not in [1.9, 2.1]" )
100+ assert(weight1 >= - 1.60 && weight1 <= - 1.40 , weight1 + " not in [-1.6, -1.4]" )
101+ assert(weight2 >= - 1.0e-3 && weight2 <= 1.0e-3 , weight2 + " not in [-0.001, 0.001]" )
93102
94103 val validationData = LinearDataGenerator .generateLinearInput(A , Array [Double ](B ,C ), nPoints, 17 )
104+ .map { case LabeledPoint (label, features) =>
105+ LabeledPoint (label, Vectors .dense(1.0 +: features.toArray))
106+ }
95107 val validationRDD = sc.parallelize(validationData,2 )
96108
97109 // Test prediction on RDD.
0 commit comments