Skip to content

Commit 62a2c3e

Browse files
committed
use axpy and in-place if possible
1 parent 9a75ebd commit 62a2c3e

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import breeze.linalg.{Vector => BV}
2020

2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.rdd.RDD
23+
import breeze.linalg.axpy
2324

2425
case class VectorRDDStatisticalSummary(
2526
mean: Vector,
@@ -58,17 +59,22 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5859
BV.fill(size){Double.MaxValue}))(
5960
seqOp = (c, v) => (c, v) match {
6061
case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
61-
val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
62-
val nonZeroCnt = Vectors
63-
.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
62+
val currMean = prevMean :* (cnt / (cnt + 1.0))
63+
axpy(1.0/(cnt+1.0), currData, currMean)
64+
axpy(-1.0, currData, prevMean)
65+
prevMean :*= (currMean - currData)
66+
axpy(1.0, prevMean, prevM2n)
67+
axpy(1.0,
68+
Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze,
69+
nnzVec)
6470
currData.activeIterator.foreach { case (id, value) =>
6571
if (maxVec(id) < value) maxVec(id) = value
6672
if (minVec(id) > value) minVec(id) = value
6773
}
6874
(currMean,
69-
prevM2n + ((currData - prevMean) :* (currData - currMean)),
75+
prevM2n,
7076
cnt + 1.0,
71-
nnzVec + nonZeroCnt,
77+
nnzVec,
7278
maxVec,
7379
minVec)
7480
},
@@ -77,23 +83,30 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
7783
(lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin),
7884
(rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
7985
val totalCnt = lhsCnt + rhsCnt
80-
val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
8186
val deltaMean = rhsMean - lhsMean
82-
val totalM2n =
83-
lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
87+
lhsMean :*= (lhsCnt / totalCnt)
88+
axpy(rhsCnt/totalCnt, rhsMean, lhsMean)
89+
val totalMean = lhsMean
90+
deltaMean :*= deltaMean
91+
axpy(lhsCnt*rhsCnt/totalCnt, deltaMean, lhsM2n)
92+
axpy(1.0, rhsM2n, lhsM2n)
93+
val totalM2n = lhsM2n
8494
rhsMax.activeIterator.foreach { case (id, value) =>
8595
if (lhsMax(id) < value) lhsMax(id) = value
8696
}
8797
rhsMin.activeIterator.foreach { case (id, value) =>
8898
if (lhsMin(id) > value) lhsMin(id) = value
8999
}
90-
(totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
100+
axpy(1.0, rhsNNZ, lhsNNZ)
101+
(totalMean, totalM2n, totalCnt, lhsNNZ, lhsMax, lhsMin)
91102
}
92103
)
93104

105+
results._2 :/= results._3
106+
94107
VectorRDDStatisticalSummary(
95108
Vectors.fromBreeze(results._1),
96-
Vectors.fromBreeze(results._2 :/ results._3),
109+
Vectors.fromBreeze(results._2),
97110
results._3.toLong,
98111
Vectors.fromBreeze(results._4),
99112
Vectors.fromBreeze(results._5),

0 commit comments

Comments
 (0)