Skip to content

Commit 23e2554

Browse files
tgaloppo authored and mengxr committed
SPARK-5019 [MLlib] - GaussianMixtureModel exposes instances of MultivariateGauss...
This PR modifies GaussianMixtureModel to expose instances of MultivariateGaussian rather than separate mean and covariance arrays. Author: Travis Galoppo <tjg2107@columbia.edu> Closes #4088 from tgaloppo/spark-5019 and squashes the following commits: 3ef6c7f [Travis Galoppo] In GaussianMixtureModel: Changed name of weight, gaussian to weights, gaussians. Other sources modified accordingly. 091e8da [Travis Galoppo] SPARK-5019 - GaussianMixtureModel exposes instances of MultivariateGaussian rather than mean/covariance matrices
1 parent 769aced commit 23e2554

File tree

4 files changed

+26
-31
lines changed

4 files changed

+26
-31
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ object DenseGmmEM {
5454

5555
for (i <- 0 until clusters.k) {
5656
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
57-
(clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
57+
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
5858
}
5959

6060
println("Cluster labels (first <= 100):")

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ class GaussianMixtureEM private (
134134
// diagonal covariance matrices using component variances
135135
// derived from the samples
136136
val (weights, gaussians) = initialModel match {
137-
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
138-
new MultivariateGaussian(mu, sigma)
139-
})
137+
case Some(gmm) => (gmm.weights, gmm.gaussians)
140138

141139
case None => {
142140
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
@@ -176,10 +174,7 @@ class GaussianMixtureEM private (
176174
iter += 1
177175
}
178176

179-
// Need to convert the breeze matrices to MLlib matrices
180-
val means = Array.tabulate(k) { i => gaussians(i).mu }
181-
val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
182-
new GaussianMixtureModel(weights, means, sigmas)
177+
new GaussianMixtureModel(weights, gaussians)
183178
}
184179

185180
/** Average of dense breeze vectors */

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
2020
import breeze.linalg.{DenseVector => BreezeVector}
2121

2222
import org.apache.spark.rdd.RDD
23-
import org.apache.spark.mllib.linalg.{Matrix, Vector}
23+
import org.apache.spark.mllib.linalg.Vector
2424
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
2525
import org.apache.spark.mllib.util.MLUtils
2626

@@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
3636
* covariance matrix for Gaussian i
3737
*/
3838
class GaussianMixtureModel(
39-
val weight: Array[Double],
40-
val mu: Array[Vector],
41-
val sigma: Array[Matrix]) extends Serializable {
39+
val weights: Array[Double],
40+
val gaussians: Array[MultivariateGaussian]) extends Serializable {
41+
42+
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
4243

4344
/** Number of gaussians in mixture */
44-
def k: Int = weight.length
45+
def k: Int = weights.length
4546

4647
/** Maps given points to their cluster indices. */
4748
def predict(points: RDD[Vector]): RDD[Int] = {
@@ -55,14 +56,10 @@ class GaussianMixtureModel(
5556
*/
5657
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
5758
val sc = points.sparkContext
58-
val dists = sc.broadcast {
59-
(0 until k).map { i =>
60-
new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
61-
}.toArray
62-
}
63-
val weights = sc.broadcast(weight)
59+
val bcDists = sc.broadcast(gaussians)
60+
val bcWeights = sc.broadcast(weights)
6461
points.map { x =>
65-
computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
62+
computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
6663
}
6764
}
6865

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
23+
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
2324
import org.apache.spark.mllib.util.MLlibTestSparkContext
2425
import org.apache.spark.mllib.util.TestingUtils._
2526

@@ -39,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
3940
val seeds = Array(314589, 29032897, 50181, 494821, 4660)
4041
seeds.foreach { seed =>
4142
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
42-
assert(gmm.weight(0) ~== Ew absTol 1E-5)
43-
assert(gmm.mu(0) ~== Emu absTol 1E-5)
44-
assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
43+
assert(gmm.weights(0) ~== Ew absTol 1E-5)
44+
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
45+
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
4546
}
4647
}
4748

@@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
5758
// we set an initial gaussian to induce expected results
5859
val initialGmm = new GaussianMixtureModel(
5960
Array(0.5, 0.5),
60-
Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
61-
Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
61+
Array(
62+
new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
63+
new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
64+
)
6265
)
6366

6467
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
@@ -70,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
7073
.setInitialModel(initialGmm)
7174
.run(data)
7275

73-
assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
74-
assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
75-
assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
76-
assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
77-
assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
78-
assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
76+
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
77+
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
78+
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
79+
assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
80+
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
81+
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
7982
}
8083
}

0 commit comments

Comments
 (0)