Skip to content

Commit 42b2142

Browse files
committed
Added functionality to allow setting of GMM starting point.
Added two cluster test to testing suite.
1 parent 8b633f3 commit 42b2142

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

+39-7
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,26 @@ class GaussianMixtureModelEM private (
5555
// number of samples per cluster to use when initializing Gaussians
5656
private val nSamples = 5
5757

58+
// an initializing GMM can be provided rather than using the
59+
// default random starting point
60+
private var initialGmm: Option[GaussianMixtureModel] = None
61+
5862
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
5963
def this() = this(2, 0.01, 100)
6064

65+
/** Set the initial GMM starting point, bypassing the random initialization */
66+
def setInitialGmm(gmm: GaussianMixtureModel): this.type = {
67+
if (gmm.k == k) {
68+
initialGmm = Some(gmm)
69+
} else {
70+
throw new IllegalArgumentException("initialing GMM has mismatched cluster count (gmm.k != k)")
71+
}
72+
this
73+
}
74+
75+
/** Return the user supplied initial GMM, if supplied */
76+
def getInitialiGmm: Option[GaussianMixtureModel] = initialGmm
77+
6178
/** Set the number of Gaussians in the mixture model. Default: 2 */
6279
def setK(k: Int): this.type = {
6380
this.k = k
@@ -103,20 +120,35 @@ class GaussianMixtureModelEM private (
103120
// Get length of the input vectors
104121
val d = breezeData.first.length
105122

106-
// For each Gaussian, we will initialize the mean as the average
107-
// of some random samples from the data
108-
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
109-
110-
// gaussians will be array of (weight, mean, covariance) tuples
123+
// gaussians will be array of (weight, mean, covariance) tuples.
124+
// If the user supplied an initial GMM, we use those values, otherwise
111125
// we start with uniform weights, a random mean from the data, and
112126
// diagonal covariance matrices using component variances
113127
// derived from the samples
114-
var gaussians = (0 until k).map{ i =>
128+
var gaussians = initialGmm match {
129+
case Some(gmm) => (0 until k).map{ i =>
130+
(gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix)
131+
}.toArray
132+
133+
case None => {
134+
// For each Gaussian, we will initialize the mean as the average
135+
// of some random samples from the data
136+
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
137+
138+
(0 until k).map{ i =>
139+
(1.0 / k,
140+
vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
141+
initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
142+
}.toArray
143+
}
144+
}
145+
146+
/*var gaussians = (0 until k).map{ i =>
115147
(1.0 / k,
116148
vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
117149
initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
118150
}.toArray
119-
151+
*/
120152
val accW = new Array[Accumulator[Double]](k)
121153
val accMu = new Array[Accumulator[DenseDoubleVector]](k)
122154
val accSigma = new Array[Accumulator[DenseDoubleMatrix]](k)

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

+33
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,37 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
4242
assert(gmm.mu(0) ~== Emu absTol 1E-5)
4343
assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
4444
}
45+
46+
test("two clusters") {
47+
val data = sc.parallelize(Array(
48+
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
49+
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
50+
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
51+
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
52+
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
53+
))
54+
55+
// we set an initial gaussian to induce expected results
56+
val initialGmm = new GaussianMixtureModel(
57+
Array(0.5, 0.5),
58+
Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
59+
Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
60+
)
61+
62+
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
63+
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
64+
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
65+
66+
val gmm = new GaussianMixtureModelEM()
67+
.setK(2)
68+
.setInitialGmm(initialGmm)
69+
.run(data)
70+
71+
assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
72+
assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
73+
assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
74+
assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
75+
assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
76+
assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
77+
}
4578
}

0 commit comments

Comments
 (0)