16
16
*/
17
17
package org .apache .spark .mllib .rdd
18
18
19
- import breeze .linalg .{Vector => BV }
19
+ import breeze .linalg .{Vector => BV , axpy }
20
20
21
21
import org .apache .spark .mllib .linalg .{Vector , Vectors }
22
22
import org .apache .spark .rdd .RDD
@@ -92,8 +92,14 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
92
92
VectorRDDStatisticalRing (mean2, m2n2, cnt2, nnz2, max2, min2)) =>
93
93
val totalCnt = cnt1 + cnt2
94
94
val deltaMean = mean2 - mean1
95
- val totalMean = ((mean1 :* nnz1) + (mean2 :* nnz2)) :/ (nnz1 + nnz2)
96
- val totalM2n = m2n1 + m2n2 + ((deltaMean :* deltaMean) :* (nnz1 :* nnz2) :/ (nnz1 + nnz2))
95
+ mean2.activeIterator.foreach {
96
+ case (id, 0.0 ) =>
97
+ case (id, value) => mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
98
+ }
99
+ m2n2.activeIterator.foreach {
100
+ case (id, 0.0 ) =>
101
+ case (id, value) => m2n1(id) += value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+ nnz2(id))
102
+ }
97
103
max2.activeIterator.foreach {
98
104
case (id, value) =>
99
105
if (max1(id) < value) max1(id) = value
@@ -102,7 +108,8 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
102
108
case (id, value) =>
103
109
if (min1(id) > value) min1(id) = value
104
110
}
105
- VectorRDDStatisticalRing (totalMean, totalM2n, totalCnt, nnz1 + nnz2, max1, min1)
111
+ axpy(1.0 , nnz2, nnz1)
112
+ VectorRDDStatisticalRing (mean1, m2n1, totalCnt, nnz1, max1, min1)
106
113
}
107
114
}
108
115
0 commit comments