Skip to content

Commit 036b7a5

Browse files
committed
Fix the bug that caused NaN values to occur
1 parent f6e8e9a commit 036b7a5

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package org.apache.spark.mllib.rdd
1818

19-
import breeze.linalg.{Vector => BV}
19+
import breeze.linalg.{Vector => BV, axpy}
2020

2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.rdd.RDD
@@ -92,8 +92,14 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
9292
VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
9393
val totalCnt = cnt1 + cnt2
9494
val deltaMean = mean2 - mean1
95-
val totalMean = ((mean1 :* nnz1) + (mean2 :* nnz2)) :/ (nnz1 + nnz2)
96-
val totalM2n = m2n1 + m2n2 + ((deltaMean :* deltaMean) :* (nnz1 :* nnz2) :/ (nnz1 + nnz2))
95+
mean2.activeIterator.foreach {
96+
case (id, 0.0) =>
97+
case (id, value) => mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
98+
}
99+
m2n2.activeIterator.foreach {
100+
case (id, 0.0) =>
101+
case (id, value) => m2n1(id) += value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id))
102+
}
97103
max2.activeIterator.foreach {
98104
case (id, value) =>
99105
if (max1(id) < value) max1(id) = value
@@ -102,7 +108,8 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
102108
case (id, value) =>
103109
if (min1(id) > value) min1(id) = value
104110
}
105-
VectorRDDStatisticalRing(totalMean, totalM2n, totalCnt, nnz1 + nnz2, max1, min1)
111+
axpy(1.0, nnz2, nnz1)
112+
VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1)
106113
}
107114
}
108115

0 commit comments

Comments
 (0)