Commit 834ada2

optimized MLUtils.computeStats
update some ml algorithms to use Vector (cont.)
1 parent 135ab72 commit 834ada2

7 files changed, +84 -81 lines changed


mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 9 additions & 10 deletions
@@ -22,7 +22,7 @@ import org.apache.spark.rdd.RDD
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.
@@ -139,47 +139,46 @@ object GradientDescent extends Logging {
       numIterations: Int,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeights: Vector): (Vector, Vector) = {
+      initialWeights: Vector): (Vector, Array[Double]) = {
 
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
 
     val nexamples: Long = data.count()
     val miniBatchSize = nexamples * miniBatchFraction
 
     // Initialize weights as a column vector
-    var weights = initialWeights.toBreeze.toDenseVector
+    var weights = Vectors.dense(initialWeights.toArray)
 
     /**
      * For the first iteration, the regVal will be initialized as sum of sqrt of
      * weights if it's L2 update; for L1 update, the same logic is followed.
      */
     var regVal = updater.compute(
-      weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2
+      weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
 
     for (i <- 1 to numIterations) {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
         case (y, features) =>
-          val featuresCol = new DoubleMatrix(features.length, 1, features:_*)
-          val (grad, loss) = gradient.compute(featuresCol, y, weights)
-          (grad, loss)
-      }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
+          val (grad, loss) = gradient.compute(features, y, weights)
+          (grad.toBreeze, loss)
+      }.reduce((a, b) => (a._1 += b._1, a._2 + b._2))
 
       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
        * and regVal is the regularization value computed in the previous iteration as well.
        */
       stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
       val update = updater.compute(
-        weights, gradientSum.div(miniBatchSize), stepSize, i, regParam)
+        weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
       weights = update._1
       regVal = update._2
     }
 
     logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
       stochasticLossHistory.takeRight(10).mkString(", ")))
 
-    (weights.toArray, stochasticLossHistory.toArray)
+    (weights, stochasticLossHistory.toArray)
   }
 }
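
The map/reduce above now sums gradients as breeze vectors, using in-place `+=` where the jblas version called `addi`. A minimal local sketch of that accumulation pattern, assuming only breeze on the classpath (the object name, gradients, and losses are made up for illustration):

import breeze.linalg.DenseVector

object GradientSumSketch {
  def main(args: Array[String]): Unit = {
    // Per-example (gradient, loss) pairs, standing in for the mapped mini-batch.
    val perExample = Seq(
      (DenseVector(1.0, 2.0), 0.5),
      (DenseVector(0.5, -1.0), 0.25),
      (DenseVector(-1.5, 0.0), 0.75))

    // Same shape as the reduce in runMiniBatchSGD: `a._1 += b._1` adds into
    // the left vector in place and returns it, so no fresh vector is
    // allocated per pair.
    val (gradientSum, lossSum) =
      perExample.reduce((a, b) => (a._1 += b._1, a._2 + b._2))

    println(gradientSum)  // DenseVector(0.0, 1.0)
    println(lossSum)      // 1.5
  }
}

On an RDD the in-place addition is safe because each reduce operand is a deserialized per-partition copy; on a local Seq, as here, it mutates the first tuple's vector.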

mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,6 @@
 package org.apache.spark.mllib.optimization
 
 import scala.math._
-import org.jblas.DoubleMatrix
 
 import breeze.linalg.{norm => brzNorm}
 
@@ -122,7 +121,7 @@ class SquaredL2Updater extends Updater {
       gradient: Vector,
       stepSize: Double,
       iter: Int,
-      regParam: Double): (DoubleMatrix, Double) = {
+      regParam: Double): (Vector, Double) = {
     // add up both updates from the gradient of the loss (= step) as well as
     // the gradient of the regularizer (= regParam * weightsOld)
     // w' = w - thisIterStepSize * (gradient + regParam * w)
@@ -132,7 +131,7 @@ class SquaredL2Updater extends Updater {
       (gradient.toBreeze * thisIterStepSize)
     val norm = brzNorm(brzWeights, 2.0)
 
-    (Vectors.fromBreeze(newWeights), 0.5 * regParam * norm * norm)
+    (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
   }
 }
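
The context lines spell out the rule this class implements: w' = w - thisIterStepSize * (gradient + regParam * w), with the second tuple element being the regularization value 0.5 * regParam * ||w'||^2 computed on the freshly updated weights (the diff's brzWeights). A standalone breeze sketch of one such step, with made-up weights, gradient, and hyperparameters:

import breeze.linalg.{DenseVector, norm}

object SquaredL2UpdateSketch {
  def main(args: Array[String]): Unit = {
    val weightsOld = DenseVector(0.8, -0.3)
    val gradient   = DenseVector(0.1, 0.2)
    val (stepSize, iter, regParam) = (1.0, 4, 0.1)

    // Step size decays as stepSize / sqrt(iter), as in GradientDescent.
    val thisIterStepSize = stepSize / math.sqrt(iter)

    // w' = w - thisIterStepSize * (gradient + regParam * w), folded into a
    // shrinkage of the old weights plus a plain gradient step.
    val weightsNew = weightsOld * (1.0 - thisIterStepSize * regParam) -
      gradient * thisIterStepSize

    // Regularization value returned alongside the weights:
    // 0.5 * regParam * ||w'||^2, on the updated weights.
    val l2 = norm(weightsNew, 2.0)
    println((weightsNew, 0.5 * regParam * l2 * l2))
  }
}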

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 2 additions & 2 deletions
@@ -118,8 +118,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
 
   /** Prepends one to the input vector. */
   private def prependOne(vector: Vector): Vector = {
-    val vectorWithIntercept = vector match {
-      case dv: BDV[Double] => BDV.vertcat(BDV.ones(1), dv)
+    val vectorWithIntercept = vector.toBreeze match {
+      case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv)
       case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv)
       case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
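
prependOne now matches on vector.toBreeze rather than on the mllib Vector itself (which is never a BDV or BSV), and BDV.ones gains the explicit [Double] that the bare call site cannot infer. A minimal sketch of the two vertcat cases, assuming only breeze; the object name and values are made up:

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}

object PrependOneSketch {
  def main(args: Array[String]): Unit = {
    // Dense case: [1.0] ++ [2.0, 3.0] => [1.0, 2.0, 3.0].
    val dense = BDV.vertcat(BDV.ones[Double](1), BDV(2.0, 3.0))
    println(dense)

    // Sparse case: a length-1 sparse vector holding 1.0 at index 0,
    // concatenated in front of the sparse features; the result has
    // length 4 with 1.0 at index 0 and 5.0 at index 2.
    val one = new BSV[Double](Array(0), Array(1.0), 1)
    val sparse = BSV.vertcat(one, new BSV[Double](Array(1), Array(5.0), 3))
    println(sparse)
  }
}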

mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala

Lines changed: 8 additions & 9 deletions
@@ -21,8 +21,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.util.MLUtils
-
-import org.jblas.DoubleMatrix
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
  * Regression model trained using LinearRegression.
@@ -31,15 +30,15 @@ import org.jblas.DoubleMatrix
  * @param intercept Intercept computed for this model.
  */
 class LinearRegressionModel(
-    override val weights: Array[Double],
+    override val weights: Vector,
     override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
 
   override protected def predictPoint(
-      dataMatrix: DoubleMatrix,
-      weightMatrix: DoubleMatrix,
+      dataMatrix: Vector,
+      weightMatrix: Vector,
       intercept: Double): Double = {
-    dataMatrix.dot(weightMatrix) + intercept
+    weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
   }
 }
 
@@ -69,7 +68,7 @@ class LinearRegressionWithSGD private (
    */
   def this() = this(1.0, 100, 1.0)
 
-  override protected def createModel(weights: Array[Double], intercept: Double) = {
+  override protected def createModel(weights: Vector, intercept: Double) = {
     new LinearRegressionModel(weights, intercept)
   }
 }
@@ -98,7 +97,7 @@ object LinearRegressionWithSGD {
       numIterations: Int,
       stepSize: Double,
       miniBatchFraction: Double,
-      initialWeights: Array[Double])
+      initialWeights: Vector)
     : LinearRegressionModel =
   {
     new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input,
@@ -172,7 +171,7 @@ object LinearRegressionWithSGD {
     val sc = new SparkContext(args(0), "LinearRegression")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
-    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Weights: " + model.weights)
     println("Intercept: " + model.intercept)
 
     sc.stop()
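
With weights now an mllib Vector, predictPoint reduces to a breeze dot product plus the intercept. A local sketch of that arithmetic with made-up weights and features, assuming only breeze:

import breeze.linalg.DenseVector

object PredictPointSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for weightMatrix.toBreeze and dataMatrix.toBreeze.
    val weights   = DenseVector(0.5, -1.0, 2.0)
    val features  = DenseVector(1.0, 2.0, 0.5)
    val intercept = 0.1

    // Same arithmetic as the new predictPoint.
    val prediction = weights.dot(features) + intercept
    println(prediction)  // 0.5*1.0 + (-1.0)*2.0 + 2.0*0.5 + 0.1 = -0.4
  }
}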

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

Lines changed: 23 additions & 20 deletions
@@ -21,7 +21,9 @@ import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
 
 /**
  * Regression model trained using RidgeRegression.
@@ -36,10 +38,10 @@ class RidgeRegressionModel(
   with RegressionModel with Serializable {
 
   override protected def predictPoint(
-      dataMatrix: DoubleMatrix,
-      weightMatrix: DoubleMatrix,
+      dataMatrix: Vector,
+      weightMatrix: Vector,
       intercept: Double): Double = {
-    dataMatrix.dot(weightMatrix) + intercept
+    weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
   }
 }
 
@@ -71,8 +73,8 @@ class RidgeRegressionWithSGD private (
   super.setIntercept(false)
 
   var yMean = 0.0
-  var xColMean: DoubleMatrix = _
-  var xColSd: DoubleMatrix = _
+  var xColMean: BV[Double] = _
+  var xColSd: BV[Double] = _
 
   /**
   * Construct a RidgeRegression object with default parameters
@@ -85,33 +87,33 @@ class RidgeRegressionWithSGD private (
     this
   }
 
-  override protected def createModel(weights: Array[Double], intercept: Double) = {
-    val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
-    val weightsScaled = weightsMat.div(xColSd)
-    val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
+  override protected def createModel(weights: Vector, intercept: Double) = {
+    val weightsMat = weights.toBreeze
+    val weightsScaled = weightsMat :/ xColSd
+    val interceptScaled = yMean - weightsMat.dot(xColMean :/ xColSd)
 
-    new RidgeRegressionModel(weightsScaled.data, interceptScaled)
+    new RidgeRegressionModel(Vectors.fromBreeze(weightsScaled), interceptScaled)
   }
 
   override def run(
       input: RDD[LabeledPoint],
-      initialWeights: Array[Double])
+      initialWeights: Vector)
     : RidgeRegressionModel =
   {
-    val nfeatures: Int = input.first().features.length
+    val nfeatures: Int = input.first().features.size
     val nexamples: Long = input.count()
 
     // To avoid penalizing the intercept, we center and scale the data.
     val stats = MLUtils.computeStats(input, nfeatures, nexamples)
     yMean = stats._1
-    xColMean = stats._2
-    xColSd = stats._3
+    xColMean = stats._2.toBreeze
+    xColSd = stats._3.toBreeze
 
     val normalizedData = input.map { point =>
       val yNormalized = point.label - yMean
-      val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
-      val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
-      LabeledPoint(yNormalized, featuresNormalized.toArray)
+      val featuresMat = point.features.toBreeze
+      val featuresNormalized = (featuresMat - xColMean) :/ xColSd
+      LabeledPoint(yNormalized, Vectors.fromBreeze(featuresNormalized))
    }
 
     super.run(normalizedData, initialWeights)
@@ -143,7 +145,7 @@ object RidgeRegressionWithSGD {
       stepSize: Double,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeights: Array[Double])
+      initialWeights: Vector)
     : RidgeRegressionModel =
   {
     new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(
@@ -220,7 +222,8 @@ object RidgeRegressionWithSGD {
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
       args(3).toDouble)
-    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+
+    println("Weights: " + model.weights)
     println("Intercept: " + model.intercept)
 
     sc.stop()
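
The ridge path standardizes each column, fits in the scaled space, and createModel then maps the solution back: the weights are divided by the column standard deviations and the centering is folded into the intercept. A breeze sketch of that round trip on made-up statistics; note the diff writes element-wise division as :/ (the spelling in the breeze this commit builds against), which current breeze spells /:/ as used below:

import breeze.linalg.DenseVector

object RidgeRescaleSketch {
  def main(args: Array[String]): Unit = {
    // Made-up column statistics, standing in for MLUtils.computeStats output.
    val xColMean = DenseVector(2.0, -1.0)
    val xColSd   = DenseVector(0.5, 2.0)
    val yMean    = 3.0

    // Weights learned on the centered, scaled data.
    val weights = DenseVector(1.0, 4.0)

    // Same arithmetic as createModel: undo the scaling on the weights and
    // fold the centering into the intercept.
    val weightsScaled   = weights /:/ xColSd
    val interceptScaled = yMean - weights.dot(xColMean /:/ xColSd)

    println(weightsScaled)    // DenseVector(2.0, 2.0)
    println(interceptScaled)  // 3.0 - (4.0 - 2.0) = 1.0
  }
}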

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 39 additions & 34 deletions
@@ -17,15 +17,13 @@
 
 package org.apache.spark.mllib.util
 
+import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
+  squaredDistance => breezeSquaredDistance}
+
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
-
-import org.jblas.DoubleMatrix
-
 import org.apache.spark.mllib.regression.LabeledPoint
-
-import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
  * Helper methods to load, save and pre-process data used in ML Lib.
@@ -54,7 +52,7 @@ object MLUtils {
     sc.textFile(dir).map { line =>
       val parts = line.split(',')
       val label = parts(0).toDouble
-      val features = parts(1).trim().split(' ').map(_.toDouble)
+      val features = Vectors.dense(parts(1).trim().split(' ').map(_.toDouble))
       LabeledPoint(label, features)
     }
   }
@@ -68,52 +66,59 @@ object MLUtils {
   * @param dir Directory to save the data.
   */
  def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
-    val dataStr = data.map(x => x.label + "," + x.features.mkString(" "))
+    val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" "))
    dataStr.saveAsTextFile(dir)
  }
 
  /**
   * Utility function to compute mean and standard deviation on a given dataset.
   *
   * @param data - input data set whose statistics are computed
-   * @param nfeatures - number of features
-   * @param nexamples - number of examples in input dataset
+   * @param numFeatures - number of features
+   * @param numExamples - number of examples in input dataset
   *
   * @return (yMean, xColMean, xColSd) - Tuple consisting of
   *     yMean - mean of the labels
   *     xColMean - Row vector with mean for every column (or feature) of the input data
   *     xColSd - Row vector standard deviation for every column (or feature) of the input data.
   */
-  def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long):
-    (Double, DoubleMatrix, DoubleMatrix) = {
-    val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples
+  def computeStats(data: RDD[LabeledPoint], numFeatures: Int, numExamples: Long)
+    : (Double, Vector, Vector) = {
 
-    // NOTE: We shuffle X by column here to compute column sum and sum of squares.
-    val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint =>
-      val nCols = labeledPoint.features.length
-      // Traverse over every column and emit (col, value, value^2)
-      Iterator.tabulate(nCols) { i =>
-        (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i)))
-      }
-    }.reduceByKey { case(x1, x2) =>
-      (x1._1 + x2._1, x1._2 + x2._2)
+    val brzData = data.map { case LabeledPoint(label, features) =>
+      (label, features.toBreeze)
     }
-    val xColSumsMap = xColSumSq.collectAsMap()
 
-    val xColMean = DoubleMatrix.zeros(nfeatures, 1)
-    val xColSd = DoubleMatrix.zeros(nfeatures, 1)
+    val aggStats = brzData.aggregate(
+      (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
+    )(
+      seqOp = (c, v) => (c, v) match {
+        case ((n, sumLabel, sum, sumSq), (label, features)) =>
+          features.activeIterator.foreach { case (i, x) =>
+            sumSq(i) += x * x
+          }
+          (n + 1L, sumLabel + label, sum += features, sumSq)
+      },
+      combOp = (c1, c2) => (c1, c2) match {
+        case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
+          (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
+      }
+    )
+
+    val (nl, sumLabel, sum, sumSq) = aggStats
+    require(nl > 0, "Input data is empty.")
 
-    // Compute mean and unbiased variance using column sums
-    var col = 0
-    while (col < nfeatures) {
-      xColMean.put(col, xColSumsMap(col)._1 / nexamples)
-      val variance =
-        (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples
-      xColSd.put(col, math.sqrt(variance))
-      col += 1
+    val n = nl.toDouble
+    val yMean = sumLabel / n
+    val mean: BDV[Double] = sum / n
+    val std = new Array[Double](sum.length)
+    var i = 0
+    while (i < numFeatures) {
+      std(i) = sumSq(i) / n - mean(i) * mean(i)
+      i += 1
     }
 
-    (yMean, xColMean, xColSd)
+    (yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
  }
 
  /**
mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala

Lines changed: 1 addition & 3 deletions
@@ -17,14 +17,12 @@
 
 package org.apache.spark.mllib.regression
 
+import org.scalatest.FunSuite
 
 import org.jblas.DoubleMatrix
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-
 class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
 
   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {