Skip to content

Commit 548e9de

Browse files
committed
minor revision
1 parent 86522c4 commit 548e9de

File tree

2 files changed

+35
-36
lines changed

2 files changed

+35
-36
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.mllib.rdd
1919

20-
import breeze.linalg.{axpy, Vector => BV}
20+
import breeze.linalg.{Vector => BV, DenseVector => BDV}
2121

2222
import org.apache.spark.mllib.linalg.{Vectors, Vector}
2323
import org.apache.spark.rdd.RDD
@@ -29,60 +29,59 @@ import org.apache.spark.rdd.RDD
2929
trait VectorRDDStatisticalSummary {
3030
def mean: Vector
3131
def variance: Vector
32-
def totalCount: Long
32+
def count: Long
3333
def numNonZeros: Vector
3434
def max: Vector
3535
def min: Vector
3636
}
3737

3838
/**
3939
* Aggregates [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]]
40-
* together with add() and merge() function.
40+
* together with add() and merge() function. Online variance solution used in add() function, while
41+
* parallel variance solution used in merge() function. Reference here:
42+
* [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. Solution here
43+
* ignoring the zero elements when calling add() and merge(), for decreasing the O(n) algorithm to
44+
* O(nnz). Real variance is computed here after we get other statistics, simply by another parallel
45+
* combination process.
4146
*/
42-
private class Aggregator(
43-
val currMean: BV[Double],
44-
val currM2n: BV[Double],
47+
private class VectorRDDStatisticsAggregator(
48+
val currMean: BDV[Double],
49+
val currM2n: BDV[Double],
4550
var totalCnt: Double,
46-
val nnz: BV[Double],
47-
val currMax: BV[Double],
48-
val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
51+
val nnz: BDV[Double],
52+
val currMax: BDV[Double],
53+
val currMin: BDV[Double]) extends VectorRDDStatisticalSummary with Serializable {
4954

5055
// lazy val is used for computing only once time. Same below.
5156
override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
5257

53-
// Online variance solution used in add() function, while parallel variance solution used in
54-
// merge() function. Reference here:
55-
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
56-
// Solution here ignoring the zero elements when calling add() and merge(), for decreasing the
57-
// O(n) algorithm to O(nnz). Real variance is computed here after we get other statistics, simply
58-
// by another parallel combination process.
5958
override lazy val variance = {
6059
val deltaMean = currMean
6160
var i = 0
62-
while(i < currM2n.size) {
63-
currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt-nnz(i)) / totalCnt
61+
while (i < currM2n.size) {
62+
currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
6463
currM2n(i) /= totalCnt
6564
i += 1
6665
}
6766
Vectors.fromBreeze(currM2n)
6867
}
6968

70-
override lazy val totalCount: Long = totalCnt.toLong
69+
override lazy val count: Long = totalCnt.toLong
7170

7271
override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)
7372

7473
override lazy val max: Vector = {
7574
nnz.iterator.foreach {
7675
case (id, count) =>
77-
if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
76+
if ((count < totalCnt) && (currMax(id) < 0.0)) currMax(id) = 0.0
7877
}
7978
Vectors.fromBreeze(currMax)
8079
}
8180

8281
override lazy val min: Vector = {
8382
nnz.iterator.foreach {
8483
case (id, count) =>
85-
if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
84+
if ((count < totalCnt) && (currMin(id) > 0.0)) currMin(id) = 0.0
8685
}
8786
Vectors.fromBreeze(currMin)
8887
}
@@ -92,7 +91,7 @@ private class Aggregator(
9291
*/
9392
def add(currData: BV[Double]): this.type = {
9493
currData.activeIterator.foreach {
95-
// this case is used for filtering the zero elements if the vector is a dense one.
94+
// this case is used for filtering the zero elements if the vector.
9695
case (id, 0.0) =>
9796
case (id, value) =>
9897
if (currMax(id) < value) currMax(id) = value
@@ -112,7 +111,7 @@ private class Aggregator(
112111
/**
113112
* Combine function used for combining intermediate results together from every worker.
114113
*/
115-
def merge(other: Aggregator): this.type = {
114+
def merge(other: VectorRDDStatisticsAggregator): this.type = {
116115

117116
totalCnt += other.totalCnt
118117

@@ -145,7 +144,7 @@ private class Aggregator(
145144
if (currMin(id) > value) currMin(id) = value
146145
}
147146

148-
axpy(1.0, other.nnz, nnz)
147+
nnz += other.nnz
149148
this
150149
}
151150
}
@@ -160,18 +159,18 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
160159
/**
161160
* Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
162161
*/
163-
def summarizeStatistics(): VectorRDDStatisticalSummary = {
164-
val size = self.take(1).head.size
162+
def computeSummaryStatistics(): VectorRDDStatisticalSummary = {
163+
val size = self.first().size
165164

166-
val zeroValue = new Aggregator(
167-
BV.zeros[Double](size),
168-
BV.zeros[Double](size),
165+
val zeroValue = new VectorRDDStatisticsAggregator(
166+
BDV.zeros[Double](size),
167+
BDV.zeros[Double](size),
169168
0.0,
170-
BV.zeros[Double](size),
171-
BV.fill(size)(Double.MinValue),
172-
BV.fill(size)(Double.MaxValue))
169+
BDV.zeros[Double](size),
170+
BDV.fill(size)(Double.MinValue),
171+
BDV.fill(size)(Double.MaxValue))
173172

174-
self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
173+
self.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)(
175174
(aggregator, data) => aggregator.add(data),
176175
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
177176
)

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4545

4646
test("dense statistical summary") {
4747
val data = sc.parallelize(localData, 2)
48-
val summary = data.summarizeStatistics()
48+
val summary = data.computeSummaryStatistics()
4949

5050
assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
5151
"Dense column mean do not match.")
5252

5353
assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
5454
"Dense column variance do not match.")
5555

56-
assert(summary.totalCount === 3, "Dense column cnt do not match.")
56+
assert(summary.count === 3, "Dense column cnt do not match.")
5757

5858
assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
5959
"Dense column nnz do not match.")
@@ -67,15 +67,15 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
6767

6868
test("sparse statistical summary") {
6969
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
70-
val summary = dataForSparse.summarizeStatistics()
70+
val summary = dataForSparse.computeSummaryStatistics()
7171

7272
assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)),
7373
"Sparse column mean do not match.")
7474

7575
assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)),
7676
"Sparse column variance do not match.")
7777

78-
assert(summary.totalCount === 100, "Sparse column cnt do not match.")
78+
assert(summary.count === 100, "Sparse column cnt do not match.")
7979

8080
assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 1.0, 0.0)),
8181
"Sparse column nnz do not match.")

0 commit comments

Comments
 (0)