Skip to content

Commit e7d413b

Browse files
committed
Moved multivariate Gaussian utility class to mllib/stat/impl
Improved comments
1 parent 9770261 commit e7d413b

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ import org.apache.spark.mllib.linalg.Matrix
2121
import org.apache.spark.mllib.linalg.Vector
2222

2323
/**
24-
* Multivariate Gaussian mixture model consisting of k Gaussians, where points are drawn
25-
* from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are the respective
26-
* mean and covariance for each Gaussian distribution i=1..k.
24+
* Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
25+
* are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
26+
* the respective mean and covariance for each Gaussian distribution i=1..k.
2727
*
2828
* @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
2929
* the weight for Gaussian i, and weight.sum == 1

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,30 @@
1818
package org.apache.spark.mllib.clustering
1919

2020
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
21-
import breeze.linalg.{Transpose, det, inv}
21+
import breeze.linalg.Transpose
2222

2323
import org.apache.spark.rdd.RDD
2424
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
25+
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
2526
import org.apache.spark.{Accumulator, AccumulatorParam, SparkContext}
2627
import org.apache.spark.SparkContext.DoubleAccumulatorParam
2728

2829
/**
29-
* This class performs multivariate Gaussian expectation maximization. It will
30-
* maximize the log-likelihood for a mixture of k Gaussians, iterating until
31-
* the log-likelihood changes by less than delta, or until it has reached
32-
* the max number of iterations.
30+
* This class performs expectation maximization for multivariate Gaussian
31+
* Mixture Models (GMMs). A GMM represents a composite distribution of
32+
* independent Gaussian distributions with associated "mixing" weights
33+
* specifying each one's contribution to the composite.
34+
*
35+
* Given a set of sample points, this class will maximize the log-likelihood
36+
* for a mixture of k Gaussians, iterating until the log-likelihood changes by
37+
* less than convergenceTol, or until it has reached the max number of iterations.
38+
* While this process is generally guaranteed to converge, it is not guaranteed
39+
* to find a global optimum.
40+
*
41+
* @param k The number of independent Gaussians in the mixture model
42+
* @param convergenceTol The maximum change in log-likelihood at which convergence
43+
* is considered to have occurred.
44+
* @param maxIterations The maximum number of iterations to perform
3345
*/
3446
class GaussianMixtureModelEM private (
3547
private var k: Int,
@@ -40,7 +52,7 @@ class GaussianMixtureModelEM private (
4052
private type DenseDoubleVector = BreezeVector[Double]
4153
private type DenseDoubleMatrix = BreezeMatrix[Double]
4254

43-
/** number of samples per cluster to use when initializing Gaussians */
55+
// number of samples per cluster to use when initializing Gaussians
4456
private val nSamples = 5
4557

4658
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
@@ -219,21 +231,4 @@ class GaussianMixtureModelEM private (
219231
a += b
220232
}
221233
}
222-
223-
/**
224-
* Utility class to implement the density function for multivariate Gaussian distribution.
225-
* Breeze provides this functionality, but it requires the Apache Commons Math library,
226-
* so this class is here so-as to not introduce a new dependency in Spark.
227-
*/
228-
private class MultivariateGaussian(val mu: DenseDoubleVector, val sigma: DenseDoubleMatrix)
229-
extends Serializable {
230-
private val sigmaInv2 = inv(sigma) * -0.5
231-
private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)
232-
233-
def pdf(x: DenseDoubleVector): Double = {
234-
val delta = x - mu
235-
val deltaTranspose = new Transpose(delta)
236-
U * math.exp(deltaTranspose * sigmaInv2 * delta)
237-
}
238-
}
239234
}

0 commit comments

Comments
 (0)