Skip to content

Commit 138300c

Browse files
committed
add new Aggregator class
1 parent 1376ff4 commit 138300c

File tree

2 files changed

+96
-9
lines changed

2 files changed

+96
-9
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,109 @@ package org.apache.spark.mllib.rdd
1818

1919
import breeze.linalg.{axpy, Vector => BV}
2020

21-
import org.apache.spark.mllib.linalg.Vector
21+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
2222
import org.apache.spark.rdd.RDD
2323

2424
/**
2525
* Case class of the summary statistics, including mean, variance, count, max, min, and non-zero
2626
* elements count.
2727
*/
28+
trait VectorRDDStatisticalSummary {
29+
def mean(): Vector
30+
def variance(): Vector
31+
def totalCount(): Long
32+
def numNonZeros(): Vector
33+
def max(): Vector
34+
def min(): Vector
35+
}
36+
37+
private class Aggregator(
38+
val currMean: BV[Double],
39+
val currM2n: BV[Double],
40+
var totalCnt: Double,
41+
val nnz: BV[Double],
42+
val currMax: BV[Double],
43+
val currMin: BV[Double]) extends VectorRDDStatisticalSummary {
44+
nnz.activeIterator.foreach {
45+
case (id, 0.0) =>
46+
currMax(id) = 0.0
47+
currMin(id) = 0.0
48+
case _ =>
49+
}
50+
override def mean(): Vector = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
51+
override def variance(): Vector = {
52+
val deltaMean = currMean
53+
val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
54+
realM2n :/= totalCnt
55+
Vectors.fromBreeze(realM2n)
56+
}
57+
58+
override def totalCount(): Long = totalCnt.toLong
59+
60+
override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
61+
override def max(): Vector = Vectors.fromBreeze(currMax)
62+
override def min(): Vector = Vectors.fromBreeze(currMin)
63+
/**
64+
* Aggregate function used for aggregating elements in a worker together.
65+
*/
66+
def add(currData: BV[Double]): this.type = {
67+
currData.activeIterator.foreach {
68+
case (id, 0.0) =>
69+
case (id, value) =>
70+
if (currMax(id) < value) currMax(id) = value
71+
if (currMin(id) > value) currMin(id) = value
72+
73+
val tmpPrevMean = currMean(id)
74+
currMean(id) = (currMean(id) * totalCnt + value) / (totalCnt + 1.0)
75+
currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
76+
77+
nnz(id) += 1.0
78+
totalCnt += 1.0
79+
}
80+
this
81+
}
82+
/**
83+
* Combine function used for combining intermediate results together from every worker.
84+
*/
85+
def merge(other: this.type): this.type = {
86+
totalCnt += other.totalCnt
87+
val deltaMean = currMean - other.currMean
88+
89+
other.currMean.activeIterator.foreach {
90+
case (id, 0.0) =>
91+
case (id, value) =>
92+
currMean(id) = (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id))
93+
}
94+
95+
other.currM2n.activeIterator.foreach {
96+
case (id, 0.0) =>
97+
case (id, value) =>
98+
currM2n(id) +=
99+
value + deltaMean(id) * deltaMean(id) * nnz(id) * other.nnz(id) / (nnz(id)+other.nnz(id))
100+
}
101+
102+
other.currMax.activeIterator.foreach {
103+
case (id, value) =>
104+
if (currMax(id) < value) currMax(id) = value
105+
}
106+
107+
other.currMin.activeIterator.foreach {
108+
case (id, value) =>
109+
if (currMin(id) > value) currMin(id) = value
110+
}
111+
112+
axpy(1.0, other.nnz, nnz)
113+
this
114+
}
115+
}
116+
28117
case class VectorRDDStatisticalAggregator(
29118
mean: BV[Double],
30-
statCounter: BV[Double],
31-
totalCount: Double,
32-
numNonZeros: BV[Double],
33-
max: BV[Double],
34-
min: BV[Double])
119+
statCnt: BV[Double],
120+
totalCnt: Double,
121+
nnz: BV[Double],
122+
currMax: BV[Double],
123+
currMin: BV[Double])
35124

36125
/**
37126
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3838
)
3939

4040
val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0))))
41-
for (i <- 0 until 10000) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
41+
for (i <- 0 until 100) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
4242
sparseData += Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0)))
4343
sparseData += Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0)))
4444

@@ -63,8 +63,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
6363
val (_, sparseTime) = time(dataForSparse.summarizeStatistics())
6464

6565
println(s"dense time is $denseTime, sparse time is $sparseTime.")
66-
assert(relativeTime(denseTime, sparseTime),
67-
"Relative time between dense and sparse vector doesn't match.")
6866
}
6967
}
7068

0 commit comments

Comments
 (0)