Skip to content

Commit fdfb45e

Browse files
yinxusenpwendell
authored andcommitted
[WIP] [SPARK-1328] Add vector statistics
As with the new vector system in MLlib, we find that it is good to add some new APIs to precess the `RDD[Vector]`. Beside, the former implementation of `computeStat` is not stable which could loss precision, and has the possibility to cause `Nan` in scientific computing, just as said in the [SPARK-1328](https://spark-project.atlassian.net/browse/SPARK-1328). APIs contain: * rowMeans(): RDD[Double] * rowNorm2(): RDD[Double] * rowSDs(): RDD[Double] * colMeans(): Vector * colMeans(size: Int): Vector * colNorm2(): Vector * colNorm2(size: Int): Vector * colSDs(): Vector * colSDs(size: Int): Vector * maxOption((Vector, Vector) => Boolean): Option[Vector] * minOption((Vector, Vector) => Boolean): Option[Vector] * rowShrink(): RDD[Vector] * colShrink(): RDD[Vector] This is working in process now, and some more APIs will add to `LabeledPoint`. Moreover, the implicit declaration will move from `MLUtils` to `MLContext` later. Author: Xusen Yin <yinxusen@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #268 from yinxusen/vector-statistics and squashes the following commits: d61363f [Xusen Yin] rebase to latest master 16ae684 [Xusen Yin] fix minor error and remove useless method 10cf5d3 [Xusen Yin] refine some return type b064714 [Xusen Yin] remove computeStat in MLUtils cbbefdb [Xiangrui Meng] update multivariate statistical summary interface and clean tests 4eaf28a [Xusen Yin] merge VectorRDDStatistics into RowMatrix 48ee053 [Xusen Yin] fix minor error e624f93 [Xusen Yin] fix scala style error 1fba230 [Xusen Yin] merge while loop together 69e1f37 [Xusen Yin] remove lazy eval, and minor memory footprint 548e9de [Xusen Yin] minor revision 86522c4 [Xusen Yin] add comments on functions dc77e38 [Xusen Yin] test sparse vector RDD 18cf072 [Xusen Yin] change def to lazy val to make sure that the computations in function be evaluated only once f7a3ca2 [Xusen Yin] fix the corner case of maxmin 967d041 [Xusen Yin] full revision with Aggregator class 138300c [Xusen Yin] add new Aggregator class 1376ff4 [Xusen Yin] rename variables and adjust code 4a5c38d [Xusen Yin] add scala doc, refine code and comments 036b7a5 [Xusen Yin] fix the bug of Nan occur f6e8e9a [Xusen Yin] add sparse vectors test 4cfbadf [Xusen Yin] fix bug of min max 4e4fbd1 [Xusen Yin] separate seqop and combop out as independent functions a6d5a2e [Xusen Yin] rewrite for only computing non-zero elements 3980287 [Xusen Yin] rename variables 62a2c3e [Xusen Yin] use axpy and in-place if possible 9a75ebd [Xusen Yin] add case class to wrap return values d816ac7 [Xusen Yin] remove useless APIs c4651bb [Xusen Yin] remove row-wise APIs and refine code 1338ea1 [Xusen Yin] all-in-one version test passed cc65810 [Xusen Yin] add parallel mean and variance 9af2e95 [Xusen Yin] refine the code style ad6c82d [Xusen Yin] add shrink test e09d5d2 [Xusen Yin] add scala docs and refine shrink method 8ef3377 [Xusen Yin] pass all tests 28cf060 [Xusen Yin] fix error of column means 54b19ab [Xusen Yin] add new API to shrink RDD[Vector] 8c6c0e1 [Xusen Yin] add basic statistics
1 parent 7038b00 commit fdfb45e

File tree

5 files changed

+230
-76
lines changed

5 files changed

+230
-76
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 157 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,146 @@ package org.apache.spark.mllib.linalg.distributed
1919

2020
import java.util
2121

22-
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
22+
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
2323
import breeze.numerics.{sqrt => brzSqrt}
2424
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2525

2626
import org.apache.spark.annotation.Experimental
2727
import org.apache.spark.mllib.linalg._
2828
import org.apache.spark.rdd.RDD
2929
import org.apache.spark.Logging
30+
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
31+
32+
/**
33+
* Column statistics aggregator implementing
34+
* [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
35+
* together with add() and merge() function.
36+
* A numerically stable algorithm is implemented to compute sample mean and variance:
37+
*[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
38+
* Zero elements (including explicit zero values) are skipped when calling add() and merge(),
39+
* to have time complexity O(nnz) instead of O(n) for each column.
40+
*/
41+
private class ColumnStatisticsAggregator(private val n: Int)
42+
extends MultivariateStatisticalSummary with Serializable {
43+
44+
private val currMean: BDV[Double] = BDV.zeros[Double](n)
45+
private val currM2n: BDV[Double] = BDV.zeros[Double](n)
46+
private var totalCnt = 0.0
47+
private val nnz: BDV[Double] = BDV.zeros[Double](n)
48+
private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
49+
private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
50+
51+
override def mean: Vector = {
52+
val realMean = BDV.zeros[Double](n)
53+
var i = 0
54+
while (i < n) {
55+
realMean(i) = currMean(i) * nnz(i) / totalCnt
56+
i += 1
57+
}
58+
Vectors.fromBreeze(realMean)
59+
}
60+
61+
override def variance: Vector = {
62+
val realVariance = BDV.zeros[Double](n)
63+
64+
val denominator = totalCnt - 1.0
65+
66+
// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
67+
if (denominator > 0.0) {
68+
val deltaMean = currMean
69+
var i = 0
70+
while (i < currM2n.size) {
71+
realVariance(i) =
72+
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
73+
realVariance(i) /= denominator
74+
i += 1
75+
}
76+
}
77+
78+
Vectors.fromBreeze(realVariance)
79+
}
80+
81+
override def count: Long = totalCnt.toLong
82+
83+
override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
84+
85+
override def max: Vector = {
86+
var i = 0
87+
while (i < n) {
88+
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
89+
i += 1
90+
}
91+
Vectors.fromBreeze(currMax)
92+
}
93+
94+
override def min: Vector = {
95+
var i = 0
96+
while (i < n) {
97+
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
98+
i += 1
99+
}
100+
Vectors.fromBreeze(currMin)
101+
}
102+
103+
/**
104+
* Aggregates a row.
105+
*/
106+
def add(currData: BV[Double]): this.type = {
107+
currData.activeIterator.foreach {
108+
case (_, 0.0) => // Skip explicit zero elements.
109+
case (i, value) =>
110+
if (currMax(i) < value) {
111+
currMax(i) = value
112+
}
113+
if (currMin(i) > value) {
114+
currMin(i) = value
115+
}
116+
117+
val tmpPrevMean = currMean(i)
118+
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
119+
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
120+
121+
nnz(i) += 1.0
122+
}
123+
124+
totalCnt += 1.0
125+
this
126+
}
127+
128+
/**
129+
* Merges another aggregator.
130+
*/
131+
def merge(other: ColumnStatisticsAggregator): this.type = {
132+
require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
133+
134+
totalCnt += other.totalCnt
135+
val deltaMean = currMean - other.currMean
136+
137+
var i = 0
138+
while (i < n) {
139+
// merge mean together
140+
if (other.currMean(i) != 0.0) {
141+
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
142+
(nnz(i) + other.nnz(i))
143+
}
144+
// merge m2n together
145+
if (nnz(i) + other.nnz(i) != 0.0) {
146+
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
147+
(nnz(i) + other.nnz(i))
148+
}
149+
if (currMax(i) < other.currMax(i)) {
150+
currMax(i) = other.currMax(i)
151+
}
152+
if (currMin(i) > other.currMin(i)) {
153+
currMin(i) = other.currMin(i)
154+
}
155+
i += 1
156+
}
157+
158+
nnz += other.nnz
159+
this
160+
}
161+
}
30162

31163
/**
32164
* :: Experimental ::
@@ -182,13 +314,7 @@ class RowMatrix(
182314
combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
183315
)
184316

185-
// Update _m if it is not set, or verify its value.
186-
if (nRows <= 0L) {
187-
nRows = m
188-
} else {
189-
require(nRows == m,
190-
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
191-
}
317+
updateNumRows(m)
192318

193319
mean :/= m.toDouble
194320

@@ -240,6 +366,19 @@ class RowMatrix(
240366
}
241367
}
242368

369+
/**
370+
* Computes column-wise summary statistics.
371+
*/
372+
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
373+
val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
374+
val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
375+
(aggregator, data) => aggregator.add(data),
376+
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
377+
)
378+
updateNumRows(summary.count)
379+
summary
380+
}
381+
243382
/**
244383
* Multiply this matrix by a local matrix on the right.
245384
*
@@ -276,6 +415,16 @@ class RowMatrix(
276415
}
277416
mat
278417
}
418+
419+
/** Updates or verfires the number of rows. */
420+
private def updateNumRows(m: Long) {
421+
if (nRows <= 0) {
422+
nRows == m
423+
} else {
424+
require(nRows == m,
425+
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
426+
}
427+
}
279428
}
280429

281430
object RowMatrix {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.stat
19+
20+
import org.apache.spark.mllib.linalg.Vector
21+
22+
/**
23+
* Trait for multivariate statistical summary of a data matrix.
24+
*/
25+
trait MultivariateStatisticalSummary {
26+
27+
/**
28+
* Sample mean vector.
29+
*/
30+
def mean: Vector
31+
32+
/**
33+
* Sample variance vector. Should return a zero vector if the sample size is 1.
34+
*/
35+
def variance: Vector
36+
37+
/**
38+
* Sample size.
39+
*/
40+
def count: Long
41+
42+
/**
43+
* Number of nonzero elements (including explicitly presented zero values) in each column.
44+
*/
45+
def numNonzeros: Vector
46+
47+
/**
48+
* Maximum value of each column.
49+
*/
50+
def max: Vector
51+
52+
/**
53+
* Minimum value of each column.
54+
*/
55+
def min: Vector
56+
}

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 2 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
package org.apache.spark.mllib.util
1919

20-
import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
21-
squaredDistance => breezeSquaredDistance}
20+
import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
2221

2322
import org.apache.spark.annotation.Experimental
2423
import org.apache.spark.SparkContext
2524
import org.apache.spark.rdd.RDD
2625
import org.apache.spark.mllib.regression.LabeledPoint
27-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
26+
import org.apache.spark.mllib.linalg.Vectors
2827

2928
/**
3029
* Helper methods to load, save and pre-process data used in ML Lib.
@@ -158,58 +157,6 @@ object MLUtils {
158157
dataStr.saveAsTextFile(dir)
159158
}
160159

161-
/**
162-
* Utility function to compute mean and standard deviation on a given dataset.
163-
*
164-
* @param data - input data set whose statistics are computed
165-
* @param numFeatures - number of features
166-
* @param numExamples - number of examples in input dataset
167-
*
168-
* @return (yMean, xColMean, xColSd) - Tuple consisting of
169-
* yMean - mean of the labels
170-
* xColMean - Row vector with mean for every column (or feature) of the input data
171-
* xColSd - Row vector standard deviation for every column (or feature) of the input data.
172-
*/
173-
private[mllib] def computeStats(
174-
data: RDD[LabeledPoint],
175-
numFeatures: Int,
176-
numExamples: Long): (Double, Vector, Vector) = {
177-
val brzData = data.map { case LabeledPoint(label, features) =>
178-
(label, features.toBreeze)
179-
}
180-
val aggStats = brzData.aggregate(
181-
(0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
182-
)(
183-
seqOp = (c, v) => (c, v) match {
184-
case ((n, sumLabel, sum, sumSq), (label, features)) =>
185-
features.activeIterator.foreach { case (i, x) =>
186-
sumSq(i) += x * x
187-
}
188-
(n + 1L, sumLabel + label, sum += features, sumSq)
189-
},
190-
combOp = (c1, c2) => (c1, c2) match {
191-
case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
192-
(n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
193-
}
194-
)
195-
val (nl, sumLabel, sum, sumSq) = aggStats
196-
197-
require(nl > 0, "Input data is empty.")
198-
require(nl == numExamples)
199-
200-
val n = nl.toDouble
201-
val yMean = sumLabel / n
202-
val mean = sum / n
203-
val std = new Array[Double](sum.length)
204-
var i = 0
205-
while (i < numFeatures) {
206-
std(i) = sumSq(i) / n - mean(i) * mean(i)
207-
i += 1
208-
}
209-
210-
(yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
211-
}
212-
213160
/**
214161
* Returns the squared Euclidean distance between two vectors. The following formula will be used
215162
* if it does not introduce too much numerical error:

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,19 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
170170
))
171171
}
172172
}
173+
174+
test("compute column summary statistics") {
175+
for (mat <- Seq(denseMat, sparseMat)) {
176+
val summary = mat.computeColumnSummaryStatistics()
177+
// Run twice to make sure no internal states are changed.
178+
for (k <- 0 to 1) {
179+
assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
180+
assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
181+
assert(summary.count === m, "count mismatch.")
182+
assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
183+
assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
184+
assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.")
185+
}
186+
}
187+
}
173188
}

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import com.google.common.base.Charsets
2727
import com.google.common.io.Files
2828

2929
import org.apache.spark.mllib.linalg.Vectors
30-
import org.apache.spark.mllib.regression.LabeledPoint
3130
import org.apache.spark.mllib.util.MLUtils._
3231

3332
class MLUtilsSuite extends FunSuite with LocalSparkContext {
@@ -56,18 +55,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
5655
}
5756
}
5857

59-
test("compute stats") {
60-
val data = Seq.fill(3)(Seq(
61-
LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
62-
LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
63-
)).flatten
64-
val rdd = sc.parallelize(data, 2)
65-
val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
66-
assert(meanLabel === 0.5)
67-
assert(mean === Vectors.dense(2.0, 3.0, 4.0))
68-
assert(std === Vectors.dense(1.0, 1.0, 1.0))
69-
}
70-
7158
test("loadLibSVMData") {
7259
val lines =
7360
"""

0 commit comments

Comments
 (0)