@@ -18,20 +18,109 @@ package org.apache.spark.mllib.rdd
import breeze.linalg.{axpy, Vector => BV}

- import org.apache.spark.mllib.linalg.Vector
+ import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.rdd.RDD

/**
 * Trait for the summary statistics, including mean, variance, count, max, min, and non-zero
 * elements count.
 */
+ trait VectorRDDStatisticalSummary {
+   def mean(): Vector
+   def variance(): Vector
+   def totalCount(): Long
+   def numNonZeros(): Vector
+   def max(): Vector
+   def min(): Vector
+ }
+
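+ // Usage sketch (hypothetical wiring, for illustration only): an Aggregator is meant
+ // to be driven by RDD.aggregate, with add as the per-partition seqOp and merge as
+ // the cross-partition combOp:
+ //   data.aggregate(zero)((agg, v) => agg.add(v), (a, b) => a.merge(b))
+ // where `zero` holds zero vectors for mean/M2n/nnz and +/-Infinity fills for min/max.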
+ private class Aggregator(
+     val currMean: BV[Double],
+     val currM2n: BV[Double],
+     var totalCnt: Double,
+     val nnz: BV[Double],
+     val currMax: BV[Double],
+     val currMin: BV[Double]) extends VectorRDDStatisticalSummary {
+
+   // A column with no non-zero entries (nnz == 0) contains only zeros, so both its max
+   // and its min are 0.0.
+   nnz.activeIterator.foreach {
+     case (id, 0.0) =>
+       currMax(id) = 0.0
+       currMin(id) = 0.0
+     case _ =>
+   }
+
+   // currMean stores the mean over the non-zero entries of each column; scaling by
+   // nnz / totalCnt converts it to the mean over all entries, zeros included.
+   override def mean(): Vector = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
+
+   // currM2n is the sum of squared deviations from the non-zero mean; the correction
+   // term accounts for the (totalCnt - nnz) implicit zeros before dividing by the count.
+   override def variance(): Vector = {
+     val deltaMean = currMean
+     val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
+     realM2n :/= totalCnt
+     Vectors.fromBreeze(realM2n)
+   }
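+
+   // Worked check (illustrative): a column with values (0, 0, 3, 5) has totalCnt = 4,
+   // nnz = 2, non-zero mean 4, and currM2n = (3-4)^2 + (5-4)^2 = 2; the correction gives
+   //   realM2n = 2 - (4^2 * 2 * (2 - 4)) / 4 = 18,  and 18 / 4 = 4.5 = Var(0, 0, 3, 5).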
+
+   override def totalCount(): Long = totalCnt.toLong
+
+   override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
+
+   override def max(): Vector = Vectors.fromBreeze(currMax)
+
+   override def min(): Vector = Vectors.fromBreeze(currMin)
+
+   /**
+    * Aggregate function used for folding one vector into the running statistics on a
+    * worker.
+    */
+   def add(currData: BV[Double]): this.type = {
+     currData.activeIterator.foreach {
+       // Zero entries are skipped; they are accounted for through nnz and totalCnt.
+       case (id, 0.0) =>
+       case (id, value) =>
+         if (currMax(id) < value) currMax(id) = value
+         if (currMin(id) > value) currMin(id) = value
+
+         // Welford-style online update of the mean and the second central moment,
+         // tracked over the non-zero entries only (hence nnz(id), not totalCnt).
+         val tmpPrevMean = currMean(id)
+         currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0)
+         currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
+
+         nnz(id) += 1.0
+     }
+     // totalCnt counts vectors, not entries, so it is incremented once per call.
+     totalCnt += 1.0
+     this
+   }
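+
+   // The combine step below uses the pairwise update for means and second moments
+   // (cf. Chan et al.): for one column with counts n_a, n_b and means m_a, m_b,
+   //   mean = (m_a * n_a + m_b * n_b) / (n_a + n_b)
+   //   M2 = M2_a + M2_b + (m_a - m_b)^2 * n_a * n_b / (n_a + n_b)
+   // applied per column over the non-zero counts nnz of each side.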
82
+   /**
+    * Combine function used for combining the intermediate results from different
+    * workers.
+    */
+   def merge(other: Aggregator): this.type = {
+     totalCnt += other.totalCnt
+     // Capture the difference of the means before currMean is updated in place.
+     val deltaMean = currMean - other.currMean
+
+     other.currMean.activeIterator.foreach {
+       case (id, 0.0) =>
+       case (id, value) =>
+         currMean(id) = (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) /
+           (nnz(id) + other.nnz(id))
+     }
+
+     other.currM2n.activeIterator.foreach {
+       case (id, 0.0) =>
+       case (id, value) =>
+         currM2n(id) +=
+           value + deltaMean(id) * deltaMean(id) * nnz(id) * other.nnz(id) /
+             (nnz(id) + other.nnz(id))
+     }
+
+     other.currMax.activeIterator.foreach {
+       case (id, value) =>
+         if (currMax(id) < value) currMax(id) = value
+     }
+
+     other.currMin.activeIterator.foreach {
+       case (id, value) =>
+         if (currMin(id) > value) currMin(id) = value
+     }
+
+     // BLAS-style axpy: nnz := 1.0 * other.nnz + nnz, merging the non-zero counts.
+     axpy(1.0, other.nnz, nnz)
+     this
+   }
+ }
+
case class VectorRDDStatisticalAggregator(
    mean: BV[Double],
-     statCounter: BV[Double],
-     totalCount: Double,
-     numNonZeros: BV[Double],
-     max: BV[Double],
-     min: BV[Double])
+     statCnt: BV[Double],
+     totalCnt: Double,
+     nnz: BV[Double],
+     currMax: BV[Double],
+     currMin: BV[Double])

/**
 * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an