Skip to content

Commit fd2c032

Browse files
MechCoder authored and mengxr committed
[SPARK-5021] [MLlib] Gaussian Mixture now supports Sparse Input
Following discussion in the Jira.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4459 from MechCoder/sparse_gmm and squashes the following commits:

1b18dab [MechCoder] Rewrite syr for sparse matrices
e579041 [MechCoder] Add test for covariance matrix
5cb370b [MechCoder] Separate tests for sparse data
5e096bd [MechCoder] Alphabetize and correct error message
e180f4c [MechCoder] [SPARK-5021] Gaussian Mixture now supports Sparse Input
1 parent f98707c commit fd2c032

File tree

5 files changed

+125
-26
lines changed

5 files changed

+125
-26
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ package org.apache.spark.mllib.clustering
1919

2020
import scala.collection.mutable.IndexedSeq
2121

22-
import breeze.linalg.{DenseMatrix => BreezeMatrix, DenseVector => BreezeVector, Transpose, diag}
22+
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, SparseVector => BSV,
23+
Transpose, Vector => BV}
2324

2425
import org.apache.spark.annotation.Experimental
25-
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Matrices, Vector, Vectors}
26+
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, DenseMatrix, Matrices,
27+
SparseVector, Vector, Vectors}
2628
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
2729
import org.apache.spark.mllib.util.MLUtils
2830
import org.apache.spark.rdd.RDD
@@ -130,7 +132,7 @@ class GaussianMixture private (
130132
val sc = data.sparkContext
131133

132134
// we will operate on the data as breeze data
133-
val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
135+
val breezeData = data.map(_.toBreeze).cache()
134136

135137
// Get length of the input vectors
136138
val d = breezeData.first().length
@@ -148,7 +150,7 @@ class GaussianMixture private (
148150
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
149151
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
150152
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
151-
})
153+
})
152154
}
153155
}
154156

@@ -169,7 +171,7 @@ class GaussianMixture private (
169171
var i = 0
170172
while (i < k) {
171173
val mu = sums.means(i) / sums.weights(i)
172-
BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
174+
BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
173175
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
174176
weights(i) = sums.weights(i) / sumWeights
175177
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
@@ -185,8 +187,8 @@ class GaussianMixture private (
185187
}
186188

187189
/** Average of breeze vectors (dense or sparse); the result is always a dense vector */
188-
private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
189-
val v = BreezeVector.zeros[Double](x(0).length)
190+
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
191+
val v = BDV.zeros[Double](x(0).length)
190192
x.foreach(xi => v += xi)
191193
v / x.length.toDouble
192194
}
@@ -195,10 +197,10 @@ class GaussianMixture private (
195197
* Construct matrix where diagonal entries are element-wise
196198
* variance of input vectors (computes biased variance)
197199
*/
198-
private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
200+
private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
199201
val mu = vectorMean(x)
200-
val ss = BreezeVector.zeros[Double](x(0).length)
201-
x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
202+
val ss = BDV.zeros[Double](x(0).length)
203+
x.foreach(xi => ss += (xi - mu) :^ 2.0)
202204
diag(ss / x.length.toDouble)
203205
}
204206
}
@@ -207,27 +209,26 @@ class GaussianMixture private (
207209
private object ExpectationSum {
208210
def zero(k: Int, d: Int): ExpectationSum = {
209211
new ExpectationSum(0.0, Array.fill(k)(0.0),
210-
Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
212+
Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
211213
}
212214

213215
// compute cluster contributions for each input point
214216
// (U, T) => U for aggregation
215217
def add(
216218
weights: Array[Double],
217219
dists: Array[MultivariateGaussian])
218-
(sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
220+
(sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
219221
val p = weights.zip(dists).map {
220222
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
221223
}
222224
val pSum = p.sum
223225
sums.logLikelihood += math.log(pSum)
224-
val xxt = x * new Transpose(x)
225226
var i = 0
226227
while (i < sums.k) {
227228
p(i) /= pSum
228229
sums.weights(i) += p(i)
229230
sums.means(i) += x * p(i)
230-
BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
231+
BLAS.syr(p(i), Vectors.fromBreeze(x),
231232
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
232233
i = i + 1
233234
}
@@ -239,7 +240,7 @@ private object ExpectationSum {
239240
private class ExpectationSum(
240241
var logLikelihood: Double,
241242
val weights: Array[Double],
242-
val means: Array[BreezeVector[Double]],
243+
val means: Array[BDV[Double]],
243244
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
244245

245246
val k = weights.length

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging {
235235
* @param x the vector x that contains the n elements.
236236
* @param A the symmetric matrix A. Size of n x n.
237237
*/
238-
def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
238+
def syr(alpha: Double, x: Vector, A: DenseMatrix) {
239239
val mA = A.numRows
240240
val nA = A.numCols
241-
require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
241+
require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
242242
require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
243243

244+
x match {
245+
case dv: DenseVector => syr(alpha, dv, A)
246+
case sv: SparseVector => syr(alpha, sv, A)
247+
case _ =>
248+
throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
249+
}
250+
}
251+
252+
private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
253+
val nA = A.numRows
254+
val mA = A.numCols
255+
244256
nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
245257

246258
// Fill lower triangular part of A
@@ -255,6 +267,26 @@ private[spark] object BLAS extends Serializable with Logging {
255267
}
256268
}
257269

270+
/**
 * Sparse overload of the rank-1 update: adds alpha * x(i) * x(j) to the entry of A
 * addressed by every pair (i, j) of stored (non-zero) entries of x.
 *
 * Every ordered pair of stored entries is visited, so both (i, j) and (j, i) are
 * written and no separate "fill the other triangle" pass is needed here.
 * Entries of A whose row or column index is not stored in x are left untouched.
 *
 * NOTE(review): the flat offset is computed as `index * A.numCols + index`; this relies on
 * A being square, which the public `syr(alpha, x: Vector, A)` entry point checks via `require`.
 *
 * @param alpha scalar multiplier applied to the outer product.
 * @param x     sparse vector whose stored indices/values drive the update.
 * @param A     dense matrix updated in place through its `values` array.
 */
private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = {
  val stride = A.numCols
  val activeIndices = x.indices
  val activeValues = x.values
  val activeCount = activeValues.length
  val target = A.values

  var outer = 0
  while (outer < activeCount) {
    // Scale once per outer entry; the inner loop only multiplies by the second factor.
    val scaled = alpha * activeValues(outer)
    val base = activeIndices(outer) * stride
    var inner = 0
    while (inner < activeCount) {
      target(activeIndices(inner) + base) += scaled * activeValues(inner)
      inner += 1
    }
    outer += 1
  }
}
289+
258290
/**
259291
* C := alpha * A * B + beta * C
260292
* @param alpha a scalar to scale the multiplication A * B.

mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.mllib.stat.distribution
1919

20-
import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
20+
import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV}
2121

2222
import org.apache.spark.annotation.DeveloperApi;
2323
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
@@ -62,21 +62,21 @@ class MultivariateGaussian (
6262

6363
/** Returns density of this multivariate Gaussian at given point, x */
6464
def pdf(x: Vector): Double = {
65-
pdf(x.toBreeze.toDenseVector)
65+
pdf(x.toBreeze)
6666
}
6767

6868
/** Returns the log-density of this multivariate Gaussian at given point, x */
6969
def logpdf(x: Vector): Double = {
70-
logpdf(x.toBreeze.toDenseVector)
70+
logpdf(x.toBreeze)
7171
}
7272

7373
/** Returns density of this multivariate Gaussian at given point, x */
74-
private[mllib] def pdf(x: DBV[Double]): Double = {
74+
private[mllib] def pdf(x: BV[Double]): Double = {
7575
math.exp(logpdf(x))
7676
}
7777

7878
/** Returns the log-density of this multivariate Gaussian at given point, x */
79-
private[mllib] def logpdf(x: DBV[Double]): Double = {
79+
private[mllib] def logpdf(x: BV[Double]): Double = {
8080
val delta = x - breezeMu
8181
val v = rootSigmaInv * delta
8282
u + v.t * v * -0.5

mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
3131
Vectors.dense(5.0, 10.0),
3232
Vectors.dense(4.0, 11.0)
3333
))
34-
34+
3535
// expectations
3636
val Ew = 1.0
3737
val Emu = Vectors.dense(5.0, 10.0)
@@ -44,6 +44,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
4444
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
4545
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
4646
}
47+
4748
}
4849

4950
test("two clusters") {
@@ -54,7 +55,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
5455
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
5556
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
5657
))
57-
58+
5859
// we set an initial gaussian to induce expected results
5960
val initialGmm = new GaussianMixtureModel(
6061
Array(0.5, 0.5),
@@ -63,7 +64,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
6364
new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
6465
)
6566
)
66-
67+
6768
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
6869
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
6970
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
@@ -72,12 +73,69 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
7273
.setK(2)
7374
.setInitialModel(initialGmm)
7475
.run(data)
75-
76+
7677
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
7778
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
7879
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
7980
assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
8081
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
8182
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
8283
}
84+
85+
// Fits a one-component mixture to three sparse 3-d points and checks that the learned
// weight, mean and covariance match the values computed by hand, for several seeds.
test("single cluster with sparse data") {
  val points = Array(
    Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)),
    Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)),
    Vectors.sparse(3, Array(1), Array(6.0))
  )
  val data = sc.parallelize(points)

  // Expectations: with k = 1 the single component carries all the weight, its mean is the
  // per-coordinate average of the points, and its covariance is their divide-by-n covariance.
  val expectedWeight = 1.0
  val expectedMean = Vectors.dense(2.0, 2.0, 2.0)
  val expectedCov = Matrices.dense(3, 3, Array(
    8.0 / 3.0, -4.0, 4.0 / 3.0,
    -4.0, 8.0, -4.0,
    4.0 / 3.0, -4.0, 8.0 / 3.0
  ))

  // The result must not depend on the random initialisation.
  for (seed <- Array(42, 1994, 27, 11, 0)) {
    val model = new GaussianMixture().setK(1).setSeed(seed).run(data)
    assert(model.weights(0) ~== expectedWeight absTol 1E-5)
    assert(model.gaussians(0).mu ~== expectedMean absTol 1E-5)
    assert(model.gaussians(0).sigma ~== expectedCov absTol 1E-5)
  }
}
106+
107+
// Same scenario as the dense "two clusters" test, but the model is trained on a sparse
// encoding of the points so that the sparse code path is actually exercised.
test("two clusters with sparse data") {
  val data = sc.parallelize(Array(
    Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
    Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
    Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
    Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
    Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
  ))

  // Re-encode every 1-d point as a sparse vector with a single stored entry at index 0.
  val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))

  // we set an initial gaussian to induce expected results
  val initialGmm = new GaussianMixtureModel(
    Array(0.5, 0.5),
    Array(
      new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
      new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
    )
  )

  // expectations (identical to the dense "two clusters" test)
  val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
  val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
  val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))

  // Train on the sparse RDD. The original passed `data` here, which left `sparseData`
  // unused and made this test a duplicate of the dense one.
  val sparseGMM = new GaussianMixture()
    .setK(2)
    .setInitialModel(initialGmm)
    .run(sparseData)

  assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3)
  assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3)
  assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3)
  assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3)
  assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
  assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
83141
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ class BLASSuite extends FunSuite {
166166
syr(alpha, y, dA)
167167
}
168168
}
169+
170+
val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0))
171+
val dD = new DenseMatrix(4, 4,
172+
Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
173+
syr(0.1, xSparse, dD)
174+
val expectedSparse = new DenseMatrix(4, 4,
175+
Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4))
176+
assert(dD ~== expectedSparse absTol 1e-15)
169177
}
170178

171179
test("gemm") {

0 commit comments

Comments
 (0)