Skip to content

Commit cc65810

Browse files
committed
add parallel mean and variance
1 parent 9af2e95 commit cc65810

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.mllib.util.MLUtils._
2323
import org.apache.spark.rdd.RDD
24+
import breeze.numerics._
2425

2526
/**
2627
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -161,4 +162,24 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
161162
}
162163
}
163164
}
165+
166+
/**
 * Computes the column-wise mean and population variance of this RDD of vectors in a single
 * distributed pass.
 *
 * Each partition folds its vectors with Welford's incremental update, carrying the triple
 * (running mean, running sum of squared deviations "M2n", element count). Partial results are
 * merged pairwise with the parallel-variance combination formula.
 *
 * @param size the dimension of every vector in the RDD; used to allocate the zero accumulators.
 * @return a pair (mean, variance) where variance is M2n divided by the element count, i.e. the
 *         population (not sample) variance. NOTE(review): for an empty RDD the count is 0, so
 *         the final division is 0.0 / 0.0 and the variance entries are expected to be NaN.
 */
def parallelMeanAndVar(size: Int): (Vector, Vector) = {
  val zero = (BV.zeros[Double](size), BV.zeros[Double](size), 0.0)

  val statistics = self.map(_.toBreeze).aggregate(zero)(
    seqOp = (c, v) => (c, v) match {
      case ((prevMean, prevM2n, cnt), currData) =>
        // Welford update: the new mean is the count-weighted average of old mean and new point.
        val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
        (currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0)
    },
    combOp = (lhs, rhs) => (lhs, rhs) match {
      // An accumulator with count 0 (the zero value, or the result of an empty partition)
      // carries no information. Merging two such accumulators through the general formula
      // would divide by a total count of 0 and poison every later merge with NaN, so return
      // the other side unchanged instead.
      case ((_, _, lhsCnt), _) if lhsCnt == 0.0 => rhs
      case (_, (_, _, rhsCnt)) if rhsCnt == 0.0 => lhs
      case ((lhsMean, lhsM2n, lhsCnt), (rhsMean, rhsM2n, rhsCnt)) =>
        val totalCnt = lhsCnt + rhsCnt
        // Explicit parentheses: `:/` binds more loosely than `+`, so the whole weighted sum
        // is divided by the total count.
        val totalMean = ((lhsMean :* lhsCnt) + (rhsMean :* rhsCnt)) :/ totalCnt
        val deltaMean = rhsMean - lhsMean
        val totalM2n =
          lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
        (totalMean, totalM2n, totalCnt)
    }
  )

  val (mean, m2n, count) = statistics
  (Vectors.fromBreeze(mean), Vectors.fromBreeze(m2n :/ count))
}
164185
}

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3838
val colMeans = Array(4.0, 5.0, 6.0)
3939
val colNorm2 = Array(math.sqrt(66.0), math.sqrt(93.0), math.sqrt(126.0))
4040
val colSDs = Array(math.sqrt(6.0), math.sqrt(6.0), math.sqrt(6.0))
41+
val colVar = Array(6.0, 6.0, 6.0)
4142

4243
val maxVec = Array(7.0, 8.0, 9.0)
4344
val minVec = Array(1.0, 2.0, 3.0)
@@ -128,6 +129,13 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
128129
assert(equivVector(lhs, rhs), "Column shrink error.")
129130
}
130131
}
132+
133+
// Runs parallelMeanAndVar over the shared fixture split into two partitions and compares the
// returned column means and column variances against the expected colMeans / colVar arrays.
test("meanAndVar") {
  val data = sc.parallelize(localData, 2)
  // The second component is a variance (checked against colVar), not a standard deviation.
  val (mean, variance) = data.parallelMeanAndVar(3)
  assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.")
  assert(equivVector(variance, Vectors.dense(colVar)), "Column variances do not match.")
}
131139
}
132140

133141
object VectorRDDFunctionsSuite {

0 commit comments

Comments
 (0)