Skip to content

Commit 4a5c38d

Browse files
committed
add scala doc, refine code and comments
1 parent 036b7a5 commit 4a5c38d

File tree

2 files changed

+56
-37
lines changed

2 files changed

+56
-37
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
*/
1717
package org.apache.spark.mllib.rdd
1818

19-
import breeze.linalg.{Vector => BV, axpy}
19+
import breeze.linalg.{axpy, Vector => BV}
2020

2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.rdd.RDD
2323

24+
/**
25+
* Case class of the summary statistics, including mean, variance, count, max, min, and non-zero
26+
* elements count.
27+
*/
2428
case class VectorRDDStatisticalSummary(
2529
mean: Vector,
2630
variance: Vector,
@@ -29,6 +33,12 @@ case class VectorRDDStatisticalSummary(
2933
min: Vector,
3034
nonZeroCnt: Vector) extends Serializable
3135

36+
/**
37+
* Case class of the aggregate value for collecting summary statistics from RDD[Vector]. These
38+
* values are relatively with
39+
* [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]], the
40+
* latter is computed from the former.
41+
*/
3242
private case class VectorRDDStatisticalRing(
3343
fakeMean: BV[Double],
3444
fakeM2n: BV[Double],
@@ -45,18 +55,8 @@ private case class VectorRDDStatisticalRing(
4555
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
4656

4757
/**
48-
* Compute full column-wise statistics for the RDD, including
49-
* {{{
50-
* Mean: Vector,
51-
* Variance: Vector,
52-
* Count: Double,
53-
* Non-zero count: Vector,
54-
* Maximum elements: Vector,
55-
* Minimum elements: Vector.
56-
* }}},
57-
* with the size of Vector as input parameter.
58+
* Aggregate function used for aggregating elements in a worker together.
5859
*/
59-
6060
private def seqOp(
6161
aggregator: VectorRDDStatisticalRing,
6262
currData: BV[Double]): VectorRDDStatisticalRing = {
@@ -84,6 +84,9 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8484
}
8585
}
8686

87+
/**
88+
* Combine function used for combining intermediate results together from every worker.
89+
*/
8790
private def combOp(
8891
statistics1: VectorRDDStatisticalRing,
8992
statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
@@ -92,27 +95,38 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
9295
VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
9396
val totalCnt = cnt1 + cnt2
9497
val deltaMean = mean2 - mean1
98+
9599
mean2.activeIterator.foreach {
96100
case (id, 0.0) =>
97-
case (id, value) => mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
101+
case (id, value) =>
102+
mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
98103
}
104+
99105
m2n2.activeIterator.foreach {
100106
case (id, 0.0) =>
101-
case (id, value) => m2n1(id) += value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id))
107+
case (id, value) =>
108+
m2n1(id) +=
109+
value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id))
102110
}
111+
103112
max2.activeIterator.foreach {
104113
case (id, value) =>
105114
if (max1(id) < value) max1(id) = value
106115
}
116+
107117
min2.activeIterator.foreach {
108118
case (id, value) =>
109119
if (min1(id) > value) min1(id) = value
110120
}
121+
111122
axpy(1.0, nnz2, nnz1)
112123
VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1)
113124
}
114125
}
115126

127+
/**
128+
* Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
129+
*/
116130
def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
117131
val zeroValue = VectorRDDStatisticalRing(
118132
BV.zeros[Double](size),
@@ -122,16 +136,17 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
122136
BV.fill(size)(Double.MinValue),
123137
BV.fill(size)(Double.MaxValue))
124138

125-
val breezeVectors = self.map(_.toBreeze)
126139
val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, fakeMax, fakeMin) =
127-
breezeVectors.aggregate(zeroValue)(seqOp, combOp)
140+
self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
128141

129142
// solve real mean
130143
val realMean = fakeMean :* nnz :/ totalCnt
131-
// solve real variance
132-
val deltaMean = fakeMean :- 0.0
133-
val realVar = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
134-
// max, min
144+
145+
// solve real m2n
146+
val deltaMean = fakeMean
147+
val realM2n = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
148+
149+
// remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
135150
val max = Vectors.sparse(size, fakeMax.activeIterator.map { case (id, value) =>
136151
if ((value == Double.MinValue) && (realMean(id) != Double.MinValue)) (id, 0.0)
137152
else (id, value)
@@ -142,11 +157,11 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
142157
}.toSeq)
143158

144159
// get variance
145-
realVar :/= totalCnt
160+
realM2n :/= totalCnt
146161

147162
VectorRDDStatisticalSummary(
148163
Vectors.fromBreeze(realMean),
149-
Vectors.fromBreeze(realVar),
164+
Vectors.fromBreeze(realM2n),
150165
totalCnt.toLong,
151166
Vectors.fromBreeze(nnz),
152167
max,

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,18 @@
1717

1818
package org.apache.spark.mllib.rdd
1919

20+
import scala.collection.mutable.ArrayBuffer
21+
2022
import org.scalatest.FunSuite
23+
2124
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2225
import org.apache.spark.mllib.util.LocalSparkContext
2326
import org.apache.spark.mllib.util.MLUtils._
24-
import scala.collection.mutable.ArrayBuffer
2527

28+
/**
29+
* Test suite for the summary statistics of RDD[Vector]. Both the accuracy and the time consuming
30+
* between dense and sparse vector are tested.
31+
*/
2632
class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
2733
import VectorRDDFunctionsSuite._
2834

@@ -33,13 +39,15 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3339
)
3440

3541
val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0))))
36-
for (i <- 0 to 10000) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
42+
for (i <- 0 until 10000) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
3743
sparseData += Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0)))
3844
sparseData += Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0)))
3945

4046
test("full-statistics") {
4147
val data = sc.parallelize(localData, 2)
42-
val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) = time(data.summarizeStatistics(3))
48+
val (VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min), denseTime) =
49+
time(data.summarizeStatistics(3))
50+
4351
assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
4452
assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
4553
assert(cnt === 3, "Column cnt do not match.")
@@ -48,21 +56,12 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4856
assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
4957

5058
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
51-
val (VectorRDDStatisticalSummary(sparseMean, sparseVariance, sparseCnt, sparseNnz, sparseMax, sparseMin), sparseTime) = time(dataForSparse.summarizeStatistics(20))
52-
/*
53-
assert(equivVector(sparseMean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
54-
assert(equivVector(sparseVariance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
55-
assert(sparseCnt === 3, "Column cnt do not match.")
56-
assert(equivVector(sparseNnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
57-
assert(equivVector(sparseMax, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
58-
assert(equivVector(sparseMin, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
59-
*/
60-
61-
59+
val (_, sparseTime) = time(dataForSparse.summarizeStatistics(20))
6260

6361
println(s"dense time is $denseTime, sparse time is $sparseTime.")
62+
assert(relativeTime(denseTime, sparseTime),
63+
"Relative time between dense and sparse vector doesn't match.")
6464
}
65-
6665
}
6766

6867
object VectorRDDFunctionsSuite {
@@ -76,5 +75,10 @@ object VectorRDDFunctionsSuite {
7675
def equivVector(lhs: Vector, rhs: Vector): Boolean = {
7776
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
7877
}
78+
79+
def relativeTime(lhs: Double, rhs: Double): Boolean = {
80+
val denominator = math.max(lhs, rhs)
81+
math.abs(lhs - rhs) / denominator < 0.3
82+
}
7983
}
8084

0 commit comments

Comments
 (0)