Skip to content

Commit 4cfbadf

Browse files
committed
Fix bug in min/max computation
1 parent 4e4fbd1 commit 4cfbadf

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import breeze.linalg.{Vector => BV}
2020

2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.rdd.RDD
23-
import breeze.linalg.axpy
2423

2524
case class VectorRDDStatisticalSummary(
2625
mean: Vector,
@@ -35,8 +34,8 @@ private case class VectorRDDStatisticalRing(
3534
fakeM2n: BV[Double],
3635
totalCnt: Double,
3736
nnz: BV[Double],
38-
max: BV[Double],
39-
min: BV[Double])
37+
fakeMax: BV[Double],
38+
fakeMin: BV[Double])
4039

4140
/**
4241
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -58,7 +57,9 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5857
* with the size of Vector as input parameter.
5958
*/
6059

61-
private def seqOp(aggregator: VectorRDDStatisticalRing, currData: BV[Double]): VectorRDDStatisticalRing = {
60+
private def seqOp(
61+
aggregator: VectorRDDStatisticalRing,
62+
currData: BV[Double]): VectorRDDStatisticalRing = {
6263
aggregator match {
6364
case VectorRDDStatisticalRing(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
6465
currData.activeIterator.foreach {
@@ -73,7 +74,8 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
7374
nnzVec(id) += 1.0
7475
}
7576

76-
VectorRDDStatisticalRing(prevMean,
77+
VectorRDDStatisticalRing(
78+
prevMean,
7779
prevM2n,
7880
cnt + 1.0,
7981
nnzVec,
@@ -82,7 +84,9 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8284
}
8385
}
8486

85-
private def combOp(statistics1: VectorRDDStatisticalRing, statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
87+
private def combOp(
88+
statistics1: VectorRDDStatisticalRing,
89+
statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
8690
(statistics1, statistics2) match {
8791
case (VectorRDDStatisticalRing(mean1, m2n1, cnt1, nnz1, max1, min1),
8892
VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
@@ -111,26 +115,34 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
111115
BV.fill(size)(Double.MinValue),
112116
BV.fill(size)(Double.MaxValue))
113117

114-
val breezeVectors = self.collect().map(_.toBreeze)
115-
val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, max, min) = breezeVectors.aggregate(zeroValue)(seqOp, combOp)
118+
val breezeVectors = self.map(_.toBreeze)
119+
val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, fakeMax, fakeMin) =
120+
breezeVectors.aggregate(zeroValue)(seqOp, combOp)
116121

117122
// solve real mean
118123
val realMean = fakeMean :* nnz :/ totalCnt
119124
// solve real variance
120125
val deltaMean = fakeMean :- 0.0
121126
val realVar = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
122-
// max, min process, in case of a column is all zero.
123-
// max :+= 0.0
124-
// min :+= 0.0
125-
127+
// max, min
128+
val max = Vectors.sparse(size, fakeMax.activeIterator.map { case (id, value) =>
129+
if ((value == Double.MinValue) && (realMean(id) != Double.MinValue)) (id, 0.0)
130+
else (id, value)
131+
}.toSeq)
132+
val min = Vectors.sparse(size, fakeMin.activeIterator.map { case (id, value) =>
133+
if ((value == Double.MaxValue) && (realMean(id) != Double.MaxValue)) (id, 0.0)
134+
else (id, value)
135+
}.toSeq)
136+
137+
// get variance
126138
realVar :/= totalCnt
127139

128140
VectorRDDStatisticalSummary(
129141
Vectors.fromBreeze(realMean),
130142
Vectors.fromBreeze(realVar),
131143
totalCnt.toLong,
132144
Vectors.fromBreeze(nnz),
133-
Vectors.fromBreeze(max),
134-
Vectors.fromBreeze(min))
145+
max,
146+
min)
135147
}
136148
}

0 commit comments

Comments (0)